diff --git a/app/assets/javascripts/discourse/app/services/presence.js b/app/assets/javascripts/discourse/app/services/presence.js new file mode 100644 index 00000000000..325fc20ea50 --- /dev/null +++ b/app/assets/javascripts/discourse/app/services/presence.js @@ -0,0 +1,470 @@ +import Service from "@ember/service"; +import EmberObject, { computed, defineProperty } from "@ember/object"; +import { readOnly } from "@ember/object/computed"; +import { ajax } from "discourse/lib/ajax"; +import { cancel, debounce, later, throttle } from "@ember/runloop"; +import Session from "discourse/models/session"; +import { Promise } from "rsvp"; +import { isTesting } from "discourse-common/config/environment"; +import User from "discourse/models/user"; + +const PRESENCE_INTERVAL_S = 30; +const PRESENCE_DEBOUNCE_MS = isTesting() ? 0 : 500; +const PRESENCE_THROTTLE_MS = isTesting() ? 0 : 5000; + +function createPromiseProxy() { + const promiseProxy = {}; + promiseProxy.promise = new Promise((resolve, reject) => { + promiseProxy.resolve = resolve; + promiseProxy.reject = reject; + }); + return promiseProxy; +} + +export class PresenceChannelNotFound extends Error {} + +// Instances of this class are handed out to consumers. They act as +// convenient proxies to the PresenceService and PresenceServiceState +class PresenceChannel extends EmberObject { + init({ name, presenceService }) { + super.init(...arguments); + this.name = name; + this.presenceService = presenceService; + defineProperty( + this, + "_presenceState", + readOnly(`presenceService._presenceChannelStates.${name}`) + ); + + this.set("present", false); + this.set("subscribed", false); + } + + // Mark the current user as 'present' in this channel + async enter() { + await this.presenceService._enter(this); + this.set("present", true); + } + + // Mark the current user as leaving this channel + async leave() { + await this.presenceService._leave(this); + this.set("present", false); + } + + async subscribe(initialData = null) { + if (this.subscribed) { + return; + } + await this.presenceService._subscribe(this, initialData); + this.set("subscribed", true); + } + + async unsubscribe() { + if (!this.subscribed) { + return; + } + await this.presenceService._unsubscribe(this); + this.set("subscribed", false); + } + + @computed("_presenceState.users", "subscribed") + get users() { + if (!this.subscribed) { + return; + } + return this._presenceState.users; + } + + @computed("_presenceState.count", "subscribed") + get count() { + if (!this.subscribed) { + return; + } + return this._presenceState.count; + } + + @computed("_presenceState.count", "subscribed") + get countOnly() { + if (!this.subscribed) { + return; + } + return this._presenceState.countOnly; + } +} + +class PresenceChannelState extends EmberObject { + init({ name, presenceService }) { + super.init(...arguments); + this.name = name; + this.set("users", null); + this.set("count", null); + this.set("countOnly", null); + this.presenceService = presenceService; + } + + // Is this PresenceChannel object currently subscribed to updates + // from the server. + @computed("_subscribedCallback") + get subscribed() { + return !!this._subscribedCallback; + } + + // Subscribe to server-side updates about the channel + // Ideally, pass an initialData object with serialized PresenceChannel::State + // data from the server (serialized via PresenceChannelStateSerializer). + // + // If initialData is not supplied, an AJAX request will be made for the information. 
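+  //
+  // The serialized shape (illustrative values only) looks like:
+  //   { count: 3, last_message_id: 12, users: [{ id: 1, username: "...", ... }] }
+  // For count_only channels, `users` is omitted and only `count` is present.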
+ async subscribe(initialData = null) { + if (this.subscribed) { + return; + } + + if (!initialData) { + try { + initialData = await ajax("/presence/get", { + data: { + channel: this.name, + }, + }); + } catch (e) { + if (e.jqXHR?.status === 404) { + throw new PresenceChannelNotFound( + `PresenceChannel '${this.name}' not found` + ); + } else { + throw e; + } + } + } + + this.set("count", initialData.count); + if (initialData.users) { + this.set("users", initialData.users); + this.set("countOnly", false); + } else { + this.set("users", null); + this.set("countOnly", true); + } + + this.lastSeenId = initialData.last_message_id; + + let callback = (data, global_id, message_id) => { + this._processMessage(data, global_id, message_id); + }; + this.presenceService.messageBus.subscribe( + `/presence${this.name}`, + callback, + this.lastSeenId + ); + + this.set("_subscribedCallback", callback); + } + + // Stop subscribing to updates from the server about this channel + unsubscribe() { + if (this.subscribed) { + this.presenceService.messageBus.unsubscribe( + `/presence${this.name}`, + this._subscribedCallback + ); + this.set("_subscribedCallback", null); + this.set("users", null); + this.set("count", null); + } + } + + async _resubscribe() { + this.unsubscribe(); + // Stored at object level for tests to hook in + this._resubscribePromise = this.subscribe(); + await this._resubscribePromise; + delete this._resubscribePromise; + } + + async _processMessage(data, global_id, message_id) { + if (message_id !== this.lastSeenId + 1) { + // eslint-disable-next-line no-console + console.log( + `PresenceChannel '${ + this.name + }' dropped message (received ${message_id}, expecting ${ + this.lastSeenId + 1 + }), resyncing...` + ); + + await this._resubscribe(); + return; + } else { + this.lastSeenId = message_id; + } + + if (this.countOnly && data.count_delta !== undefined) { + this.set("count", this.count + data.count_delta); + } else if ( + !this.countOnly && + (data.entering_users || data.leaving_user_ids) + ) { + if (data.entering_users) { + const users = data.entering_users.map((u) => User.create(u)); + this.users.addObjects(users); + } + if (data.leaving_user_ids) { + const leavingIds = new Set(data.leaving_user_ids); + const toRemove = this.users.filter((u) => leavingIds.has(u.id)); + this.users.removeObjects(toRemove); + } + this.set("count", this.users.length); + } else { + // Unexpected message + await this._resubscribe(); + return; + } + } +} + +export default class PresenceService extends Service { + init() { + super.init(...arguments); + this._presentChannels = new Set(); + this._queuedEvents = []; + this._presenceChannelStates = EmberObject.create(); + this._presentProxies = {}; + this._subscribedProxies = {}; + window.addEventListener("beforeunload", () => { + this._beaconLeaveAll(); + }); + } + + // Get a PresenceChannel object representing a single channel + getChannel(channelName) { + return PresenceChannel.create({ + name: channelName, + presenceService: this, + }); + } + + _addPresent(channelProxy) { + let present = this._presentProxies[channelProxy.name]; + if (!present) { + present = this._presentProxies[channelProxy.name] = new Set(); + } + present.add(channelProxy); + return present.size; + } + + _removePresent(channelProxy) { + let present = this._presentProxies[channelProxy.name]; + present?.delete(channelProxy); + return present?.size || 0; + } + + _addSubscribed(channelProxy) { + let subscribed = this._subscribedProxies[channelProxy.name]; + if (!subscribed) { + subscribed = 
this._subscribedProxies[channelProxy.name] = new Set(); + } + subscribed.add(channelProxy); + return subscribed.size; + } + + _removeSubscribed(channelProxy) { + let subscribed = this._subscribedProxies[channelProxy.name]; + subscribed?.delete(channelProxy); + return subscribed?.size || 0; + } + + async _enter(channelProxy) { + if (!this.currentUser) { + throw "Must be logged in to enter presence channel"; + } + + this._addPresent(channelProxy); + + const channelName = channelProxy.name; + if (this._presentChannels.has(channelName)) { + return; + } + + const promiseProxy = createPromiseProxy(); + + this._presentChannels.add(channelName); + this._queuedEvents.push({ + channel: channelName, + type: "enter", + promiseProxy: promiseProxy, + }); + + this._scheduleNextUpdate(); + + await promiseProxy.promise; + } + + async _leave(channelProxy) { + if (!this.currentUser) { + throw "Must be logged in to leave presence channel"; + } + + const presentCount = this._removePresent(channelProxy); + if (presentCount > 0) { + return; + } + + const channelName = channelProxy.name; + if (!this._presentChannels.has(channelName)) { + return; + } + + const promiseProxy = createPromiseProxy(); + + this._presentChannels.delete(channelName); + this._queuedEvents.push({ + channel: channelName, + type: "leave", + promiseProxy: promiseProxy, + }); + + this._scheduleNextUpdate(); + + await promiseProxy.promise; + } + + async _subscribe(channelProxy, initialData = null) { + this._addSubscribed(channelProxy); + const channelName = channelProxy.name; + let state = this._presenceChannelStates[channelName]; + if (!state) { + state = PresenceChannelState.create({ + name: channelName, + presenceService: this, + }); + this._presenceChannelStates.set(channelName, state); + await state.subscribe(initialData); + } + } + + _unsubscribe(channelProxy) { + const subscribedCount = this._removeSubscribed(channelProxy); + if (subscribedCount === 0) { + const channelName = channelProxy.name; + this._presenceChannelStates[channelName].unsubscribe(); + this._presenceChannelStates.set(channelName, undefined); + } + } + + _beaconLeaveAll() { + if (isTesting()) { + return; + } + this._dedupQueue(); + const channelsToLeave = this._queuedEvents + .filter((e) => e.type === "leave") + .map((e) => e.channel); + + const data = new FormData(); + data.append("client_id", this.messageBus.clientId); + this._presentChannels.forEach((ch) => data.append("leave_channels[]", ch)); + channelsToLeave.forEach((ch) => data.append("leave_channels[]", ch)); + + data.append("authenticity_token", Session.currentProp("csrfToken")); + navigator.sendBeacon("/presence/update", data); + } + + _dedupQueue() { + const deduplicated = {}; + this._queuedEvents.forEach((e) => { + if (deduplicated[e.channel]) { + deduplicated[e.channel].promiseProxy.resolve(e.promiseProxy.promise); + } + deduplicated[e.channel] = e; + }); + this._queuedEvents = Object.values(deduplicated); + } + + async _updateServer() { + this._lastUpdate = new Date(); + this._updateRunning = true; + + this._cancelTimer(); + + this._dedupQueue(); + const queue = this._queuedEvents; + this._queuedEvents = []; + + try { + const channelsToLeave = queue + .filter((e) => e.type === "leave") + .map((e) => e.channel); + + const response = await ajax("/presence/update", { + data: { + client_id: this.messageBus.clientId, + present_channels: [...this._presentChannels], + leave_channels: channelsToLeave, + }, + type: "POST", + }); + + queue.forEach((e) => { + if (response[e.channel] === false) { + 
e.promiseProxy.reject( + new PresenceChannelNotFound( + `PresenceChannel '${e.channel}' not found` + ) + ); + } else { + e.promiseProxy.resolve(); + } + }); + } catch (e) { + // Updating server failed. Put the failed events + // back in the queue for next time + this._queuedEvents.unshift(...queue); + if (e.jqXHR?.status === 429) { + // Rate limited. No need to raise, we'll try again later + } else { + throw e; + } + } finally { + this._updateRunning = false; + this._scheduleNextUpdate(); + } + } + + // `throttle` only allows triggering on the first **or** the last event + // in a sequence of calls. We want both. We want the first event, to make + // things very responsive. Then if things are happening too frequently, we + // drop back to the last event via the regular throttle function. + _throttledUpdateServer() { + if ( + !this._lastUpdate || + new Date() - this._lastUpdate > PRESENCE_THROTTLE_MS + ) { + this._updateServer(); + } else { + throttle(this, this._updateServer, PRESENCE_THROTTLE_MS, false); + } + } + + _cancelTimer() { + if (this._nextUpdateTimer) { + cancel(this._nextUpdateTimer); + this._nextUpdateTimer = null; + } + } + + _scheduleNextUpdate() { + if (this._updateRunning) { + return; + } else if (this._queuedEvents.length > 0) { + this._cancelTimer(); + debounce(this, this._throttledUpdateServer, PRESENCE_DEBOUNCE_MS); + } else if (!this._nextUpdateTimer && !isTesting()) { + this._nextUpdateTimer = later( + this, + this._throttledUpdateServer, + PRESENCE_INTERVAL_S * 1000 + ); + } + } +} diff --git a/app/assets/javascripts/discourse/tests/unit/services/presence-test.js b/app/assets/javascripts/discourse/tests/unit/services/presence-test.js new file mode 100644 index 00000000000..091951b5d0c --- /dev/null +++ b/app/assets/javascripts/discourse/tests/unit/services/presence-test.js @@ -0,0 +1,313 @@ +import { + acceptance, + publishToMessageBus, +} from "discourse/tests/helpers/qunit-helpers"; +import { test } from "qunit"; +import { PresenceChannelNotFound } from "discourse/services/presence"; + +function usersFixture() { + return [ + { + id: 1, + username: "bruce0", + name: "Bruce Wayne", + avatar_template: "/letter_avatar_proxy/v4/letter/b/90ced4/{size}.png", + }, + { + id: 2, + username: "bruce1", + name: "Bruce Wayne", + avatar_template: "/letter_avatar_proxy/v4/letter/b/9de053/{size}.png", + }, + { + id: 3, + username: "bruce2", + name: "Bruce Wayne", + avatar_template: "/letter_avatar_proxy/v4/letter/b/35a633/{size}.png", + }, + ]; +} +acceptance("Presence - Subscribing", function (needs) { + needs.pretender((server, helper) => { + server.get("/presence/get", (request) => { + if (request.queryParams.channel?.startsWith("/test/")) { + return helper.response({ + count: 3, + last_message_id: 1, + users: usersFixture(), + }); + } else if (request.queryParams.channel?.startsWith("/countonly/")) { + return helper.response({ + count: 3, + last_message_id: 1, + }); + } + + return helper.response(404, {}); + }); + }); + + test("subscribing and receiving updates", async function (assert) { + let presenceService = this.container.lookup("service:presence"); + let channel = presenceService.getChannel("/test/ch1"); + assert.equal(channel.name, "/test/ch1"); + + await channel.subscribe({ + users: usersFixture(), + last_message_id: 1, + }); + + assert.equal(channel.users.length, 3, "it starts with three users"); + + publishToMessageBus( + "/presence/test/ch1", + { + leaving_user_ids: [1], + }, + 0, + 2 + ); + + assert.equal(channel.users.length, 2, "one user is removed"); + + 
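+    // Simulate another user entering; message_id 3 continues the sequence after 2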
publishToMessageBus( + "/presence/test/ch1", + { + entering_users: [usersFixture()[0]], + }, + 0, + 3 + ); + + assert.equal(channel.users.length, 3, "one user is added"); + }); + + test("fetches data when no initial state", async function (assert) { + let presenceService = this.container.lookup("service:presence"); + let channel = presenceService.getChannel("/test/ch1"); + + await channel.subscribe(); + + assert.equal(channel.users.length, 3, "loads initial state"); + + publishToMessageBus( + "/presence/test/ch1", + { + leaving_user_ids: [1], + }, + 0, + 2 + ); + + assert.equal( + channel.users.length, + 2, + "updates following messagebus message" + ); + + publishToMessageBus( + "/presence/test/ch1", + { + leaving_user_ids: [2], + }, + 0, + 99 + ); + + await channel._presenceState._resubscribePromise; + + assert.equal( + channel.users.length, + 3, + "detects missed messagebus message, fetches data from server" + ); + }); + + test("raises error when subscribing to nonexistent channel", async function (assert) { + let presenceService = this.container.lookup("service:presence"); + let channel = presenceService.getChannel("/nonexistent/ch1"); + + assert.rejects( + channel.subscribe(), + PresenceChannelNotFound, + "raises not found" + ); + }); + + test("can subscribe to count_only channel", async function (assert) { + let presenceService = this.container.lookup("service:presence"); + let channel = presenceService.getChannel("/countonly/ch1"); + + await channel.subscribe(); + + assert.equal(channel.count, 3, "has the correct count"); + assert.equal(channel.countOnly, true, "identifies as countOnly"); + assert.equal(channel.users, null, "has null users list"); + + publishToMessageBus( + "/presence/countonly/ch1", + { + count_delta: 1, + }, + 0, + 2 + ); + + assert.equal(channel.count, 4, "updates the count via messagebus"); + + publishToMessageBus( + "/presence/countonly/ch1", + { + leaving_user_ids: [2], + }, + 0, + 3 + ); + + await channel._presenceState._resubscribePromise; + + assert.equal( + channel.count, + 3, + "resubscribes when receiving a non-count-only message" + ); + }); + + test("can share data between multiple PresenceChannel objects", async function (assert) { + let presenceService = this.container.lookup("service:presence"); + let channel = presenceService.getChannel("/test/ch1"); + let channelDup = presenceService.getChannel("/test/ch1"); + + await channel.subscribe(); + assert.equal(channel.subscribed, true, "channel is subscribed"); + assert.equal(channel.count, 3, "channel has the correct count"); + assert.equal(channel.users.length, 3, "channel has users"); + + assert.equal(channelDup.subscribed, false, "channelDup is not subscribed"); + assert.equal(channelDup.count, null, "channelDup has no count"); + assert.equal(channelDup.users, null, "channelDup has users"); + + await channelDup.subscribe(); + assert.equal(channelDup.subscribed, true, "channelDup can subscribe"); + assert.ok( + channelDup._presenceState, + "channelDup has a valid internal state" + ); + assert.equal( + channelDup._presenceState, + channel._presenceState, + "internal state is shared" + ); + + await channel.unsubscribe(); + assert.equal(channel.subscribed, false, "channel can unsubscribe"); + assert.equal( + channelDup._presenceState, + channel._presenceState, + "state is maintained" + ); + + await channelDup.unsubscribe(); + assert.equal(channel.subscribed, false, "channelDup can unsubscribe"); + assert.equal(channelDup._presenceState, undefined, "state is cleared"); + }); +}); + +acceptance("Presence - 
Entering and Leaving", function (needs) { + needs.user(); + + const requests = []; + needs.hooks.afterEach(() => requests.clear()); + needs.pretender((server, helper) => { + server.post("/presence/update", (request) => { + const body = new URLSearchParams(request.requestBody); + requests.push(body); + + const response = {}; + const channelsRequested = body.getAll("present_channels[]"); + channelsRequested.forEach((c) => { + if (c.startsWith("/test/")) { + response[c] = true; + } else { + response[c] = false; + } + }); + + return helper.response(response); + }); + }); + + test("can join and leave channels", async function (assert) { + const presenceService = this.container.lookup("service:presence"); + const channel = presenceService.getChannel("/test/ch1"); + + await channel.enter(); + assert.equal(requests.length, 1, "updated the server for enter"); + let presentChannels = requests.pop().getAll("present_channels[]"); + assert.deepEqual( + presentChannels, + ["/test/ch1"], + "included the correct present channel" + ); + + await channel.leave(); + assert.equal(requests.length, 1, "updated the server for leave"); + const request = requests.pop(); + presentChannels = request.getAll("present_channels[]"); + const leaveChannels = request.getAll("leave_channels[]"); + assert.deepEqual(presentChannels, [], "included no present channels"); + assert.deepEqual( + leaveChannels, + ["/test/ch1"], + "included the correct leave channel" + ); + }); + + test("raises an error when entering a non-existant channel", async function (assert) { + const presenceService = this.container.lookup("service:presence"); + const channel = presenceService.getChannel("/blah/doesnotexist"); + await assert.rejects( + channel.enter(), + PresenceChannelNotFound, + "raises a not found error" + ); + }); + + test("deduplicates calls from multiple PresenceChannel instances", async function (assert) { + const presenceService = this.container.lookup("service:presence"); + const channel = presenceService.getChannel("/test/ch1"); + const channelDup = presenceService.getChannel("/test/ch1"); + + await channel.enter(); + assert.equal(channel.present, true, "channel is present"); + assert.equal(channelDup.present, false, "channelDup is absent"); + assert.ok( + presenceService._presentChannels.has("/test/ch1"), + "service shows present" + ); + + await channelDup.enter(); + assert.equal(channel.present, true, "channel is present"); + assert.equal(channelDup.present, true, "channelDup is present"); + assert.ok( + presenceService._presentChannels.has("/test/ch1"), + "service shows present" + ); + + await channel.leave(); + assert.equal(channel.present, false, "channel is absent"); + assert.equal(channelDup.present, true, "channelDup is present"); + assert.ok( + presenceService._presentChannels.has("/test/ch1"), + "service shows present" + ); + + await channelDup.leave(); + assert.equal(channel.present, false, "channel is absent"); + assert.equal(channel.present, false, "channelDup is absent"); + assert.notOk( + presenceService._presentChannels.has("/test/ch1"), + "service shows absent" + ); + }); +}); diff --git a/app/controllers/presence_controller.rb b/app/controllers/presence_controller.rb new file mode 100644 index 00000000000..8fa0b137420 --- /dev/null +++ b/app/controllers/presence_controller.rb @@ -0,0 +1,64 @@ +# frozen_string_literal: true + +class PresenceController < ApplicationController + skip_before_action :check_xhr + before_action :ensure_logged_in, only: [:update] + + def get + name = params.require(:channel) + + begin + 
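+      # PresenceChannel.new raises NotFound if the name doesn't resolve to a registered channel config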
channel = PresenceChannel.new(name) + rescue PresenceChannel::NotFound + raise Discourse::NotFound + end + + if !channel.can_view?(user_id: current_user&.id) + # Do not reveal existence of channel + raise Discourse::NotFound + end + + state = channel.state + render json: state, serializer: PresenceChannelStateSerializer, root: nil + end + + def update + client_id = params[:client_id] + raise Discourse::InvalidParameters.new(:client_id) if !client_id.is_a?(String) || client_id.blank? + + # JS client is designed to throttle to one request every 5 seconds + RateLimiter.new(nil, "update-presence-#{current_user.id}-#{client_id}}", 3, 10.seconds).performed! + + present_channels = params[:present_channels] + if present_channels && !(present_channels.is_a?(Array) && present_channels.all? { |c| c.is_a? String }) + raise Discourse::InvalidParameters.new(:present_channels) + end + + leave_channels = params[:leave_channels] + if leave_channels && !(leave_channels.is_a?(Array) && leave_channels.all? { |c| c.is_a? String }) + raise Discourse::InvalidParameters.new(:leave_channels) + end + + if present_channels && present_channels.length > 50 + raise Discourse::InvalidParameters.new("Too many present_channels") + end + + response = {} + + present_channels&.each do |name| + PresenceChannel.new(name).present(user_id: current_user&.id, client_id: params[:client_id]) + response[name] = true + rescue PresenceChannel::NotFound, PresenceChannel::InvalidAccess + response[name] = false + end + + leave_channels&.each do |name| + PresenceChannel.new(name).leave(user_id: current_user&.id, client_id: params[:client_id]) + rescue PresenceChannel::NotFound + # Do nothing. Don't reveal that this channel doesn't exist + end + + render json: response + end + +end diff --git a/app/jobs/scheduled/presence_channel_auto_leave.rb b/app/jobs/scheduled/presence_channel_auto_leave.rb new file mode 100644 index 00000000000..b9ecb293423 --- /dev/null +++ b/app/jobs/scheduled/presence_channel_auto_leave.rb @@ -0,0 +1,11 @@ +# frozen_string_literal: true + +module Jobs + class PresenceChannelAutoLeave < ::Jobs::Scheduled + every PresenceChannel::DEFAULT_TIMEOUT.seconds + + def execute(args) + PresenceChannel.auto_leave_all + end + end +end diff --git a/app/serializers/presence_channel_state_serializer.rb b/app/serializers/presence_channel_state_serializer.rb new file mode 100644 index 00000000000..904ca0b7535 --- /dev/null +++ b/app/serializers/presence_channel_state_serializer.rb @@ -0,0 +1,14 @@ +# frozen_string_literal: true + +class PresenceChannelStateSerializer < ApplicationSerializer + attributes :count, :last_message_id + has_many :users, serializer: BasicUserSerializer, embed: :objects + + def last_message_id + object.message_bus_last_id + end + + def include_users? + !users.nil? 
+ end +end diff --git a/config/routes.rb b/config/routes.rb index 8a237167203..63b1498aed4 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -1006,6 +1006,9 @@ Discourse::Application.routes.draw do post "/do-not-disturb" => "do_not_disturb#create" delete "/do-not-disturb" => "do_not_disturb#destroy" + post "/presence/update" => "presence#update" + get "/presence/get" => "presence#get" + get "*url", to: 'permalinks#show', constraints: PermalinkConstraint.new end end diff --git a/lib/discourse_plugin_registry.rb b/lib/discourse_plugin_registry.rb index 097db834c55..de86a251bbc 100644 --- a/lib/discourse_plugin_registry.rb +++ b/lib/discourse_plugin_registry.rb @@ -86,6 +86,8 @@ class DiscoursePluginRegistry define_filtered_register :permitted_bulk_action_parameters define_filtered_register :reviewable_params + define_filtered_register :presence_channel_prefixes + def self.register_auth_provider(auth_provider) self.auth_providers << auth_provider end diff --git a/lib/plugin/instance.rb b/lib/plugin/instance.rb index 471f5c34e93..6e9f59c906b 100644 --- a/lib/plugin/instance.rb +++ b/lib/plugin/instance.rb @@ -923,6 +923,12 @@ class Plugin::Instance }, self) end + # Register a new PresenceChannel prefix. See {PresenceChannel.register_prefix} + # for usage instructions + def register_presence_channel_prefix(prefix, &block) + DiscoursePluginRegistry.register_presence_channel_prefix([prefix, block], self) + end + protected def self.js_path diff --git a/lib/presence_channel.rb b/lib/presence_channel.rb new file mode 100644 index 00000000000..14bd4417f01 --- /dev/null +++ b/lib/presence_channel.rb @@ -0,0 +1,627 @@ +# frozen_string_literal: true + +# The server-side implementation of PresenceChannels. See also {PresenceController} +# and +app/assets/javascripts/discourse/app/services/presence.js+ +class PresenceChannel + class NotFound < StandardError; end + class InvalidAccess < StandardError; end + class ConfigNotLoaded < StandardError; end + class InvalidConfig < StandardError; end + + class State + include ActiveModel::Serialization + + attr_reader :message_bus_last_id + attr_reader :user_ids + attr_reader :count + + def initialize(message_bus_last_id: , user_ids: nil, count: nil) + raise "user_ids or count required" if user_ids.nil? && count.nil? + @message_bus_last_id = message_bus_last_id + @user_ids = user_ids + @count = count || user_ids.count + end + + def users + return nil if user_ids.nil? + User.where(id: user_ids) + end + end + + # Class for managing config of PresenceChannel + # Three parameters can be provided on initialization: + # public: boolean value. If true, channel information is visible to all users (default false) + # allowed_user_ids: array of user_ids that can view, and become present in, the channel (default []) + # allowed_group_ids: array of group_ids that can view, and become present in, the channel (default []) + # count_only: boolean. If true, user identities are never revealed to clients. (default []) + class Config + NOT_FOUND ||= "notfound" + attr_accessor :public, :allowed_user_ids, :allowed_group_ids, :count_only + + def initialize(public: false, allowed_user_ids: nil, allowed_group_ids: nil, count_only: false) + @public = public + @allowed_user_ids = allowed_user_ids + @allowed_group_ids = allowed_group_ids + @count_only = count_only + end + + def self.from_json(json) + data = JSON.parse(json, symbolize_names: true) + data = {} if !data.is_a? 
Hash + new(**data.slice(:public, :allowed_user_ids, :allowed_group_ids, :count_only)) + end + + def to_json + data = { public: public } + data[:allowed_user_ids] = allowed_user_ids if allowed_user_ids + data[:allowed_group_ids] = allowed_group_ids if allowed_group_ids + data[:count_only] = count_only if count_only + data.to_json + end + end + + DEFAULT_TIMEOUT ||= 60 + CONFIG_CACHE_SECONDS ||= 120 + GC_SECONDS ||= 24.hours.to_i + MUTEX_TIMEOUT_SECONDS ||= 10 + MUTEX_LOCKED_ERROR ||= "PresenceChannel mutex is locked" + + @@configuration_blocks ||= {} + + attr_reader :name, :timeout, :message_bus_channel_name, :config + + def initialize(name, raise_not_found: true) + @name = name + @timeout = DEFAULT_TIMEOUT + @message_bus_channel_name = "/presence#{name}" + + begin + @config = fetch_config + rescue PresenceChannel::NotFound + raise if raise_not_found + @config = Config.new + end + end + + # Is this user allowed to view this channel? + # Pass `nil` for anonymous viewers + def can_view?(user_id: nil) + return true if config.public + return true if user_id && config.allowed_user_ids&.include?(user_id) + if user_id && config.allowed_group_ids.present? + user_group_ids = GroupUser.where(user_id: user_id).pluck("group_id") + return true if (user_group_ids & config.allowed_group_ids).present? + end + false + end + + # Is a user allowed to enter this channel? + # Currently equal to the the can_view? permission + def can_enter?(user_id: nil) + return false if user_id.nil? + can_view?(user_id: user_id) + end + + # Mark a user's client as present in this channel. The client_id should be unique per + # browser tab. This method should be called repeatedly (at least once every DEFAULT_TIMEOUT) + # while the user is present in the channel. + def present(user_id:, client_id:) + raise PresenceChannel::InvalidAccess if !can_enter?(user_id: user_id) + + mutex_value = SecureRandom.hex + result = retry_on_mutex_error do + PresenceChannel.redis_eval( + :present, + redis_keys, + [name, user_id, client_id, (Time.zone.now + timeout).to_i, mutex_value] + ) + end + + if result == 1 + begin + publish_message(entering_user_ids: [user_id]) + ensure + release_mutex(mutex_value) + end + end + + auto_leave + end + + # Immediately mark a user's client as leaving the channel + def leave(user_id:, client_id:) + mutex_value = SecureRandom.hex + result = retry_on_mutex_error do + PresenceChannel.redis_eval( + :leave, + redis_keys, + [name, user_id, client_id, nil, mutex_value] + ) + end + + if result == 1 + begin + publish_message(leaving_user_ids: [user_id]) + ensure + release_mutex(mutex_value) + end + end + + auto_leave + end + + # Fetch a {PresenceChannel::State} instance representing the current state of this + # + # @param [Boolean] count_only set true to skip fetching the list of user ids from redis + def state(count_only: config.count_only) + auto_leave + + if count_only + last_id, count = retry_on_mutex_error do + PresenceChannel.redis_eval( + :count, + redis_keys, + ) + end + else + last_id, ids = retry_on_mutex_error do + PresenceChannel.redis_eval( + :user_ids, + redis_keys, + ) + end + end + count ||= ids&.count + last_id = nil if last_id == -1 + + if Rails.env.test? 
&& MessageBus.backend == :memory + # Doing it this way is not atomic, but we have no other option when + # messagebus is not using the redis backend + last_id = MessageBus.last_id(message_bus_channel_name) + end + + State.new(message_bus_last_id: last_id, user_ids: ids, count: count) + end + + def user_ids + state.user_ids + end + + def count + state(count_only: true).count + end + + # Automatically expire all users which have not been 'present' for more than +DEFAULT_TIMEOUT+ + def auto_leave + mutex_value = SecureRandom.hex + left_user_ids = retry_on_mutex_error do + PresenceChannel.redis_eval( + :auto_leave, + redis_keys, + [name, Time.zone.now.to_i, mutex_value] + ) + end + + if !left_user_ids.empty? + begin + publish_message(leaving_user_ids: left_user_ids) + ensure + release_mutex(mutex_value) + end + end + end + + # Clear all members of the channel. This is intended for debugging/development only + def clear + PresenceChannel.redis.del(redis_key_zlist) + PresenceChannel.redis.del(redis_key_hash) + PresenceChannel.redis.del(redis_key_config) + PresenceChannel.redis.del(redis_key_mutex) + PresenceChannel.redis.zrem(self.class.redis_key_channel_list, name) + end + + # Designed to be run periodically. Checks the channel list for channels with expired members, + # and runs auto_leave for each eligable channel + def self.auto_leave_all + channels_with_expiring_members = PresenceChannel.redis.zrangebyscore(redis_key_channel_list, '-inf', Time.zone.now.to_i) + channels_with_expiring_members.each do |name| + new(name, raise_not_found: false).auto_leave + end + end + + # Clear all known channels. This is intended for debugging/development only + def self.clear_all! + channels = PresenceChannel.redis.zrangebyscore(redis_key_channel_list, '-inf', '+inf') + channels.each do |name| + new(name, raise_not_found: false).clear + end + + config_cache_keys = PresenceChannel.redis.scan_each(match: Discourse.redis.namespace_key("_presence_*_config")).to_a + PresenceChannel.redis.del(*config_cache_keys) if config_cache_keys.present? + end + + # Shortcut to access a redis client for all PresenceChannel activities. + # PresenceChannel must use the same Redis server as MessageBus, so that + # actions can be applied atomically. For the vast majority of Discourse + # installations, this is the same Redis server as `Discourse.redis`. + def self.redis + if MessageBus.backend == :redis + MessageBus.reliable_pub_sub.send(:pub_redis) # TODO: avoid a private API? + elsif Rails.env.test? + Discourse.redis.without_namespace + else + raise "PresenceChannel is unable to access MessageBus's Redis instance" + end + end + + def self.redis_eval(key, *args) + script_sha1 = LUA_SCRIPTS_SHA1[key] + raise ArgumentError.new("No script for #{key}") if script_sha1.nil? + redis.evalsha script_sha1, *args + rescue ::Redis::CommandError => e + if e.to_s =~ /^NOSCRIPT/ + redis.eval LUA_SCRIPTS[key], *args + else + raise + end + end + + # Register a callback to configure channels with a given prefix + # Prefix must match [a-zA-Z0-9_-]+ + # + # For example, this registration will be used for + # all channels starting /topic-reply/...: + # + # register_prefix("topic-reply") do |channel_name| + # PresenceChannel::Config.new(public: true) + # end + # + # At runtime, the block will be passed a full channel name. If the channel + # should not exist, the block should return `nil`. If the channel should exist, + # the block should return a PresenceChannel::Config object. + # + # Return values may be cached for up to 2 minutes. 
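+  #
+  # A plugin could register an equivalent prefix from its plugin.rb; the prefix and
+  # channel name below are hypothetical:
+  #
+  #   register_presence_channel_prefix("my-plugin") do |channel_name|
+  #     PresenceChannel::Config.new(public: true) if channel_name == "/my-plugin/lobby"
+  #   end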
+ # + # Plugins should use the {Plugin::Instance.register_presence_channel_prefix} API instead + def self.register_prefix(prefix, &block) + raise "PresenceChannel prefix #{prefix} must match [a-zA-Z0-9_-]+" unless prefix.match? /[a-zA-Z0-9_-]+/ + raise "PresenceChannel prefix #{prefix} already registered" if @@configuration_blocks&.[](prefix) + @@configuration_blocks[prefix] = block + end + + # For use in a test environment only + def self.unregister_prefix(prefix) + raise "Only allowed in test environment" if !Rails.env.test? + @@configuration_blocks&.delete(prefix) + end + + private + + def fetch_config + cached_config = PresenceChannel.redis.get(redis_key_config) + + if cached_config == Config::NOT_FOUND + raise PresenceChannel::NotFound + elsif cached_config + Config.from_json(cached_config) + else + prefix = name[/\/([a-zA-Z0-9_-]+)\/.*/, 1] + raise PresenceChannel::NotFound if prefix.nil? + + config_block = @@configuration_blocks[prefix] + config_block ||= DiscoursePluginRegistry.presence_channel_prefixes.find { |t| t[0] == prefix }&.[](1) + raise PresenceChannel::NotFound if config_block.nil? + + result = config_block.call(name) + to_cache = if result.is_a? Config + result.to_json + elsif result.nil? + Config::NOT_FOUND + else + raise InvalidConfig.new "Expected PresenceChannel::Config or nil. Got a #{result.class.name}" + end + PresenceChannel.redis.set(redis_key_config, to_cache, ex: CONFIG_CACHE_SECONDS) + + raise PresenceChannel::NotFound if result.nil? + result + end + end + + def publish_message(entering_user_ids: nil, leaving_user_ids: nil) + message = {} + if config.count_only + message["count_delta"] = entering_user_ids&.count || 0 + message["count_delta"] -= leaving_user_ids&.count || 0 + return if message["count_delta"] == 0 + else + message["leaving_user_ids"] = leaving_user_ids if leaving_user_ids.present? + if entering_user_ids.present? + users = User.where(id: entering_user_ids) + message["entering_users"] = ActiveModel::ArraySerializer.new(users, each_serializer: BasicUserSerializer) + end + end + + params = {} + + if config.public + # no params required + elsif config.allowed_user_ids || config.allowed_group_ids + params[:user_ids] = config.allowed_user_ids + params[:group_ids] = config.allowed_group_ids + else + # nobody is allowed... don't publish anything + return + end + + MessageBus.publish(message_bus_channel_name, message.as_json, **params) + end + + # Most atomic actions are achieved via lua scripts. However, when a lua action + # will result in publishing a messagebus message, the atomicity is broken. + # + # For example, if one process is handling a 'user enter' event, and another is + # handling a 'user leave' event, we need to make sure the messagebus messages + # are published in the same sequence that the PresenceChannel lua script are run. + # + # The present/leave/auto_leave lua scripts will automatically acquire this mutex + # if needed. 
If their return value indicates a change has occured, the mutex + # should be released via #release_mutex after the messagebus message has been sent + # + # If they need a change, and the mutex is not available, they will raise an error + # and should be retried periodically + def redis_key_mutex + Discourse.redis.namespace_key("_presence_#{name}_mutex") + end + + def release_mutex(mutex_value) + PresenceChannel.redis_eval( + :release_mutex, + [redis_key_mutex], + [mutex_value] + ) + end + + def retry_on_mutex_error + attempts ||= 0 + yield + rescue ::Redis::CommandError => e + if e.to_s =~ /#{MUTEX_LOCKED_ERROR}/ && attempts < 1000 + attempts += 1 + sleep 0.001 + retry + else + raise + end + end + + # The redis key which MessageBus uses to store the 'last_id' for the channel + # associated with this PresenceChannel. + def message_bus_last_id_key + return "" if Rails.env.test? && MessageBus.backend == :memory + + # TODO: Avoid using private MessageBus methods here + encoded_channel_name = MessageBus.send(:encode_channel_name, message_bus_channel_name) + MessageBus.reliable_pub_sub.send(:backlog_id_key, encoded_channel_name) + end + + def redis_keys + [redis_key_zlist, redis_key_hash, self.class.redis_key_channel_list, message_bus_last_id_key, redis_key_mutex] + end + + # The zlist is a list of client_ids, ranked by their expiration timestamp + # we periodically delete the 'lowest ranked' items in this list based on the `timeout` of the channel + def redis_key_zlist + Discourse.redis.namespace_key("_presence_#{name}_zlist") + end + + # The hash contains a map of user_id => session_count + # when the count for a user reaches 0, the key is deleted + # We use this hash to efficiently count the number of present users + def redis_key_hash + Discourse.redis.namespace_key("_presence_#{name}_hash") + end + + # The hash contains a map of user_id => session_count + # when the count for a user reaches 0, the key is deleted + # We use this hash to efficiently count the number of present users + def redis_key_config + Discourse.redis.namespace_key("_presence_#{name}_config") + end + + # This list contains all active presence channels, ranked with the expiration timestamp of their least-recently-seen client_id + # We periodically check the 'lowest ranked' items in this list based on the `timeout` of the channel + def self.redis_key_channel_list + Discourse.redis.namespace_key("_presence_channels") + end + + COMMON_PRESENT_LEAVE_LUA = <<~LUA + local channel = ARGV[1] + local user_id = ARGV[2] + local client_id = ARGV[3] + local expires = ARGV[4] + local mutex_value = ARGV[5] + + local zlist_key = KEYS[1] + local hash_key = KEYS[2] + local channels_key = KEYS[3] + local message_bus_id_key = KEYS[4] + local mutex_key = KEYS[5] + + local mutex_locked = redis.call('EXISTS', mutex_key) == 1 + + local zlist_elem = tostring(user_id) .. " " .. 
tostring(client_id) + LUA + + UPDATE_GLOBAL_CHANNELS_LUA = <<~LUA + -- Update the global channels list with the timestamp of the oldest client + local oldest_client = redis.call('ZRANGE', zlist_key, 0, 0, 'WITHSCORES') + if table.getn(oldest_client) > 0 then + local oldest_client_expire_timestamp = oldest_client[2] + redis.call('ZADD', channels_key, tonumber(oldest_client_expire_timestamp), tostring(channel)) + else + -- The channel is now empty, delete from global list + redis.call('ZREM', channels_key, tostring(channel)) + end + LUA + + LUA_SCRIPTS ||= {} + + LUA_SCRIPTS[:present] = <<~LUA + #{COMMON_PRESENT_LEAVE_LUA} + + if mutex_locked then + local mutex_required = redis.call('HGET', hash_key, tostring(user_id)) == false + if mutex_required then + error("#{MUTEX_LOCKED_ERROR}") + end + end + + local added_clients = redis.call('ZADD', zlist_key, expires, zlist_elem) + local added_users = 0 + if tonumber(added_clients) > 0 then + local new_count = redis.call('HINCRBY', hash_key, tostring(user_id), 1) + if new_count == 1 then + added_users = 1 + redis.call('SET', mutex_key, mutex_value, 'EX', #{MUTEX_TIMEOUT_SECONDS}) + end + -- Add the channel to the global channel list. 'LT' means the value will + -- only be set if it's lower than the existing value + redis.call('ZADD', channels_key, "LT", expires, tostring(channel)) + end + + redis.call('EXPIREAT', hash_key, expires + #{GC_SECONDS}) + redis.call('EXPIREAT', zlist_key, expires + #{GC_SECONDS}) + + return added_users + LUA + + LUA_SCRIPTS[:leave] = <<~LUA + #{COMMON_PRESENT_LEAVE_LUA} + + if mutex_locked then + local user_session_count = redis.call('HGET', hash_key, tostring(user_id)) + local mutex_required = user_session_count == 1 and redis.call('ZRANK', zlist_key, zlist_elem) ~= false + if mutex_required then + error("#{MUTEX_LOCKED_ERROR}") + end + end + + -- Remove the user from the channel zlist + local removed_clients = redis.call('ZREM', zlist_key, zlist_elem) + + local removed_users = 0 + if tonumber(removed_clients) > 0 then + #{UPDATE_GLOBAL_CHANNELS_LUA} + + -- Update the user session count in the channel hash + local val = redis.call('HINCRBY', hash_key, user_id, -1) + if val <= 0 then + redis.call('HDEL', hash_key, user_id) + removed_users = 1 + redis.call('SET', mutex_key, mutex_value, 'EX', #{MUTEX_TIMEOUT_SECONDS}) + end + end + + return removed_users + LUA + + LUA_SCRIPTS[:release_mutex] = <<~LUA + local mutex_key = KEYS[1] + local expected_value = ARGV[1] + + if redis.call("GET", mutex_key) == expected_value then + redis.call("DEL", mutex_key) + end + LUA + + LUA_SCRIPTS[:user_ids] = <<~LUA + local zlist_key = KEYS[1] + local hash_key = KEYS[2] + local message_bus_id_key = KEYS[4] + local mutex_key = KEYS[5] + + if redis.call('EXISTS', mutex_key) > 0 then + error('#{MUTEX_LOCKED_ERROR}') + end + + local user_ids = redis.call('HKEYS', hash_key) + table.foreach(user_ids, function(k,v) user_ids[k] = tonumber(v) end) + + local message_bus_id = tonumber(redis.call('GET', message_bus_id_key)) + if message_bus_id == nil then + message_bus_id = -1 + end + + return { message_bus_id, user_ids } + LUA + + LUA_SCRIPTS[:count] = <<~LUA + local zlist_key = KEYS[1] + local hash_key = KEYS[2] + local message_bus_id_key = KEYS[4] + local mutex_key = KEYS[5] + + if redis.call('EXISTS', mutex_key) > 0 then + error('#{MUTEX_LOCKED_ERROR}') + end + + local message_bus_id = tonumber(redis.call('GET', message_bus_id_key)) + if message_bus_id == nil then + message_bus_id = -1 + end + + local count = redis.call('HLEN', hash_key) + + return { 
message_bus_id, count } + LUA + + LUA_SCRIPTS[:auto_leave] = <<~LUA + local zlist_key = KEYS[1] + local hash_key = KEYS[2] + local channels_key = KEYS[3] + local mutex_key = KEYS[5] + local channel = ARGV[1] + local time = ARGV[2] + local mutex_value = ARGV[3] + + local expire = redis.call('ZRANGE', zlist_key, '-inf', time, 'BYSCORE') + + local has_mutex = false + + local get_mutex = function() + if redis.call('SETNX', mutex_key, mutex_value) == 0 then + error("#{MUTEX_LOCKED_ERROR}") + end + redis.call('EXPIRE', mutex_key, #{MUTEX_TIMEOUT_SECONDS}) + has_mutex = true + end + + local expired_user_ids = {} + + local expireOld = function(k, v) + local user_id = v:match("[^ ]+") + + if (not has_mutex) and (tonumber(redis.call('HGET', hash_key, user_id)) == 1) then + get_mutex() + end + + local val = redis.call('HINCRBY', hash_key, user_id, -1) + if val <= 0 then + table.insert(expired_user_ids, tonumber(user_id)) + redis.call('HDEL', hash_key, user_id) + end + redis.call('ZREM', zlist_key, v) + end + + table.foreach(expire, expireOld) + + #{UPDATE_GLOBAL_CHANNELS_LUA} + + return expired_user_ids + LUA + LUA_SCRIPTS.freeze + + LUA_SCRIPTS_SHA1 = LUA_SCRIPTS.transform_values do |script| + Digest::SHA1.hexdigest(script) + end.freeze +end diff --git a/spec/components/presence_channel_spec.rb b/spec/components/presence_channel_spec.rb new file mode 100644 index 00000000000..1508f790146 --- /dev/null +++ b/spec/components/presence_channel_spec.rb @@ -0,0 +1,286 @@ +# frozen_string_literal: true + +require 'rails_helper' +require 'presence_channel' + +describe PresenceChannel do + fab!(:user) { Fabricate(:user) } + fab!(:group) { Fabricate(:group).tap { |g| g.add(user) } } + fab!(:user2) { Fabricate(:user) } + + before do + PresenceChannel.clear_all! + + secure_user = Fabricate(:user) + secure_group = Fabricate(:group) + PresenceChannel.register_prefix("test") do |channel| + case channel + when /\A\/test\/public\d*\z/ + PresenceChannel::Config.new(public: true) + when "/test/secureuser" + PresenceChannel::Config.new(allowed_user_ids: [ secure_user.id ]) + when "/test/securegroup" + PresenceChannel::Config.new(allowed_group_ids: [ secure_group.id ]) + when "/test/alloweduser" + PresenceChannel::Config.new(allowed_user_ids: [ user.id ]) + when "/test/allowedgroup" + PresenceChannel::Config.new(allowed_group_ids: [ group.id ]) + when "/test/noaccess" + PresenceChannel::Config.new + when "/test/countonly" + PresenceChannel::Config.new(count_only: true, public: true) + else + nil + end + end + end + + after do + PresenceChannel.clear_all! 
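+    # Remove the temporary "test" prefix (unregister_prefix is test-env only)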
+ PresenceChannel.unregister_prefix('test') + end + + it "can perform basic channel functionality" do + channel1 = PresenceChannel.new("/test/public1") + channel2 = PresenceChannel.new("/test/public1") + channel3 = PresenceChannel.new("/test/public1") + + expect(channel3.user_ids).to eq([]) + + channel1.present(user_id: user.id, client_id: 1) + channel2.present(user_id: user.id, client_id: 2) + + expect(channel3.user_ids).to eq([user.id]) + expect(channel3.count).to eq(1) + + channel1.leave(user_id: user.id, client_id: 2) + + expect(channel3.user_ids).to eq([user.id]) + expect(channel3.count).to eq(1) + + channel2.leave(user_id: user.id, client_id: 1) + + expect(channel3.user_ids).to eq([]) + expect(channel3.count).to eq(0) + end + + it "can automatically expire users" do + channel = PresenceChannel.new("/test/public1") + + channel.present(user_id: user.id, client_id: 76) + channel.present(user_id: user.id, client_id: 77) + + expect(channel.count).to eq(1) + + freeze_time Time.zone.now + 1 + PresenceChannel::DEFAULT_TIMEOUT + + expect(channel.count).to eq(0) + end + + it "correctly sends messages to message bus" do + channel = PresenceChannel.new("/test/public1") + + messages = MessageBus.track_publish(channel.message_bus_channel_name) do + channel.present(user_id: user.id, client_id: "a") + end + + data = messages.map(&:data) + expect(data.count).to eq(1) + expect(data[0].keys).to contain_exactly("entering_users") + expect(data[0]["entering_users"].map { |u| u[:id] }).to contain_exactly(user.id) + + freeze_time Time.zone.now + 1 + PresenceChannel::DEFAULT_TIMEOUT + + messages = MessageBus.track_publish(channel.message_bus_channel_name) do + channel.auto_leave + end + + data = messages.map(&:data) + expect(data.count).to eq(1) + expect(data[0].keys).to contain_exactly("leaving_user_ids") + expect(data[0]["leaving_user_ids"]).to contain_exactly(user.id) + end + + it "can track active channels, and auto_leave_all successfully" do + channel1 = PresenceChannel.new("/test/public1") + channel2 = PresenceChannel.new("/test/public2") + + channel1.present(user_id: user.id, client_id: "a") + channel2.present(user_id: user.id, client_id: "a") + + start_time = Time.zone.now + + freeze_time start_time + PresenceChannel::DEFAULT_TIMEOUT / 2 + + channel2.present(user_id: user2.id, client_id: "b") + + freeze_time start_time + PresenceChannel::DEFAULT_TIMEOUT + 1 + + messages = MessageBus.track_publish do + PresenceChannel.auto_leave_all + end + + expect(messages.map { |m| [ m.channel, m.data ] }).to contain_exactly( + ["/presence/test/public1", { "leaving_user_ids" => [user.id] }], + ["/presence/test/public2", { "leaving_user_ids" => [user.id] }] + ) + + expect(channel1.user_ids).to eq([]) + expect(channel2.user_ids).to eq([user2.id]) + end + + it 'only sends one `enter` and `leave` message' do + channel = PresenceChannel.new("/test/public1") + + messages = MessageBus.track_publish(channel.message_bus_channel_name) do + channel.present(user_id: user.id, client_id: "a") + channel.present(user_id: user.id, client_id: "a") + channel.present(user_id: user.id, client_id: "b") + end + + data = messages.map(&:data) + expect(data.count).to eq(1) + expect(data[0].keys).to contain_exactly("entering_users") + expect(data[0]["entering_users"].map { |u| u[:id] }).to contain_exactly(user.id) + + messages = MessageBus.track_publish(channel.message_bus_channel_name) do + channel.leave(user_id: user.id, client_id: "a") + channel.leave(user_id: user.id, client_id: "a") + channel.leave(user_id: user.id, client_id: "b") + 
end + + data = messages.map(&:data) + expect(data.count).to eq(1) + expect(data[0].keys).to contain_exactly("leaving_user_ids") + expect(data[0]["leaving_user_ids"]).to contain_exactly(user.id) + end + + it "will return the messagebus last_id in the state payload" do + channel = PresenceChannel.new("/test/public1") + + channel.present(user_id: user.id, client_id: "a") + channel.present(user_id: user2.id, client_id: "a") + + state = channel.state + expect(state.user_ids).to contain_exactly(user.id, user2.id) + expect(state.count).to eq(2) + expect(state.message_bus_last_id).to eq(MessageBus.last_id(channel.message_bus_channel_name)) + end + + it "sets an expiry on all channel-specific keys" do + r = Discourse.redis.without_namespace + channel = PresenceChannel.new("/test/public1") + channel.present(user_id: user.id, client_id: "a") + + channels_ttl = r.ttl(PresenceChannel.redis_key_channel_list) + expect(channels_ttl).to eq(-1) # Persistent + + initial_zlist_ttl = r.ttl(channel.send(:redis_key_zlist)) + initial_hash_ttl = r.ttl(channel.send(:redis_key_hash)) + + expect(initial_zlist_ttl).to be_between(PresenceChannel::GC_SECONDS, PresenceChannel::GC_SECONDS + 5.minutes) + expect(initial_hash_ttl).to be_between(PresenceChannel::GC_SECONDS, PresenceChannel::GC_SECONDS + 5.minutes) + + freeze_time 1.minute.from_now + + # PresenceChannel#present is responsible for bumping ttl + channel.present(user_id: user.id, client_id: "a") + + new_zlist_ttl = r.ttl(channel.send(:redis_key_zlist)) + new_hash_ttl = r.ttl(channel.send(:redis_key_hash)) + + expect(new_zlist_ttl).to be > initial_zlist_ttl + expect(new_hash_ttl).to be > initial_hash_ttl + end + + it 'handles security correctly for anon' do + expect(PresenceChannel.new("/test/public1").can_enter?(user_id: nil)).to eq(false) + expect(PresenceChannel.new("/test/secureuser").can_enter?(user_id: nil)).to eq(false) + expect(PresenceChannel.new("/test/securegroup").can_enter?(user_id: nil)).to eq(false) + expect(PresenceChannel.new("/test/noaccess").can_enter?(user_id: nil)).to eq(false) + + expect(PresenceChannel.new("/test/public1").can_view?(user_id: nil)).to eq(true) + expect(PresenceChannel.new("/test/secureuser").can_view?(user_id: nil)).to eq(false) + expect(PresenceChannel.new("/test/securegroup").can_view?(user_id: nil)).to eq(false) + expect(PresenceChannel.new("/test/noaccess").can_view?(user_id: nil)).to eq(false) + end + + it 'handles security correctly for a user' do + expect(PresenceChannel.new("/test/secureuser").can_enter?(user_id: user.id)).to eq(false) + expect(PresenceChannel.new("/test/securegroup").can_enter?(user_id: user.id)).to eq(false) + expect(PresenceChannel.new("/test/alloweduser").can_enter?(user_id: user.id)).to eq(true) + expect(PresenceChannel.new("/test/allowedgroup").can_enter?(user_id: user.id)).to eq(true) + expect(PresenceChannel.new("/test/noaccess").can_enter?(user_id: user.id)).to eq(false) + + expect(PresenceChannel.new("/test/secureuser").can_view?(user_id: user.id)).to eq(false) + expect(PresenceChannel.new("/test/securegroup").can_view?(user_id: user.id)).to eq(false) + expect(PresenceChannel.new("/test/alloweduser").can_view?(user_id: user.id)).to eq(true) + expect(PresenceChannel.new("/test/allowedgroup").can_view?(user_id: user.id)).to eq(true) + expect(PresenceChannel.new("/test/noaccess").can_view?(user_id: user.id)).to eq(false) + end + + it 'publishes messages with appropriate security' do + channel = PresenceChannel.new("/test/alloweduser") + messages = 
MessageBus.track_publish(channel.message_bus_channel_name) do + channel.present(user_id: user.id, client_id: "a") + end + expect(messages.count).to eq(1) + expect(messages[0].user_ids).to eq([user.id]) + + channel = PresenceChannel.new("/test/allowedgroup") + messages = MessageBus.track_publish(channel.message_bus_channel_name) do + channel.present(user_id: user.id, client_id: "a") + end + expect(messages.count).to eq(1) + expect(messages[0].group_ids).to eq([group.id]) + end + + it 'publishes messages correctly in count_only mode' do + channel = PresenceChannel.new("/test/countonly") + messages = MessageBus.track_publish(channel.message_bus_channel_name) do + channel.present(user_id: user.id, client_id: "a") + end + expect(messages.count).to eq(1) + expect(messages[0].data).to eq({ + "count_delta" => 1 + }) + + messages = MessageBus.track_publish(channel.message_bus_channel_name) do + channel.leave(user_id: user.id, client_id: "a") + end + expect(messages.count).to eq(1) + expect(messages[0].data).to eq({ + "count_delta" => -1 + }) + end + + it "sets a mutex when the change involves publishing messages" do + channel = PresenceChannel.new("/test/public1") + + messages_published = 0 + channel.define_singleton_method(:publish_message) do |*args, **kwargs| + val = PresenceChannel.redis.get(redis_key_mutex) + raise "Mutex was not set" if val.nil? + messages_published += 1 + end + + redis_key_mutex = Discourse.redis.namespace_key("_presence_/test/public1_mutex") + + # Enter and leave + expect(PresenceChannel.redis.get(redis_key_mutex)).to eq(nil) + channel.present(user_id: user.id, client_id: 'a') + expect(PresenceChannel.redis.get(redis_key_mutex)).to eq(nil) + channel.leave(user_id: user.id, client_id: 'a') + expect(PresenceChannel.redis.get(redis_key_mutex)).to eq(nil) + expect(messages_published).to eq(2) + + # Enter and auto_leave + channel.present(user_id: user.id, client_id: 'a') + expect(PresenceChannel.redis.get(redis_key_mutex)).to eq(nil) + freeze_time 1.hour.from_now + channel.auto_leave + expect(PresenceChannel.redis.get(redis_key_mutex)).to eq(nil) + + expect(messages_published).to eq(4) + end +end diff --git a/spec/requests/presence_controller_spec.rb b/spec/requests/presence_controller_spec.rb new file mode 100644 index 00000000000..460c826767a --- /dev/null +++ b/spec/requests/presence_controller_spec.rb @@ -0,0 +1,181 @@ +# frozen_string_literal: true + +require 'rails_helper' + +describe PresenceController do + fab!(:user) { Fabricate(:user) } + fab!(:group) { Fabricate(:group).tap { |g| g.add(user) } } + + let(:ch1) { PresenceChannel.new('/test/public1') } + let(:ch2) { PresenceChannel.new('/test/public2') } + + let(:secure_user_channel) { PresenceChannel.new('/test/secureuser') } + let(:secure_group_channel) { PresenceChannel.new('/test/securegroup') } + let(:allowed_user_channel) { PresenceChannel.new('/test/alloweduser') } + let(:allowed_group_channel) { PresenceChannel.new('/test/allowedgroup') } + let(:count_only_channel) { PresenceChannel.new('/test/countonly') } + + before do + PresenceChannel.clear_all! 
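+    # Register a temporary "test" prefix covering public, user/group-restricted, and count-only channels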
+ + secure_user = Fabricate(:user) + secure_group = Fabricate(:group) + PresenceChannel.register_prefix("test") do |channel| + case channel + when /\A\/test\/public\d*\z/ + PresenceChannel::Config.new(public: true) + when "/test/secureuser" + PresenceChannel::Config.new(allowed_user_ids: [ secure_user.id ]) + when "/test/securegroup" + PresenceChannel::Config.new(allowed_group_ids: [ secure_group.id ]) + when "/test/alloweduser" + PresenceChannel::Config.new(allowed_user_ids: [ user.id ]) + when "/test/allowedgroup" + PresenceChannel::Config.new(allowed_group_ids: [ group.id ]) + when "/test/countonly" + PresenceChannel::Config.new(public: true, count_only: true) + else + nil + end + end + end + + after do + PresenceChannel.clear_all! + PresenceChannel.unregister_prefix("test") + end + + describe "#update" do + it "works" do + sign_in(user) + client_id = SecureRandom.hex + + expect(ch1.user_ids).to eq([]) + expect(ch2.user_ids).to eq([]) + + post "/presence/update.json", params: { + client_id: client_id, + present_channels: [ch1.name, ch2.name] + } + expect(response.status).to eq(200) + expect(ch1.user_ids).to eq([user.id]) + expect(ch2.user_ids).to eq([user.id]) + + post "/presence/update.json", params: { + client_id: client_id, + present_channels: [ch1.name], + leave_channels: [ch2.name] + } + expect(response.status).to eq(200) + expect(ch1.user_ids).to eq([user.id]) + expect(ch2.user_ids).to eq([]) + + post "/presence/update.json", params: { + client_id: client_id, + present_channels: [], + leave_channels: [ch1.name] + } + expect(response.status).to eq(200) + expect(ch1.user_ids).to eq([]) + expect(ch2.user_ids).to eq([]) + end + + it "returns true/false based on channel existence/security" do + sign_in(user) + client_id = SecureRandom.hex + + expect(ch1.user_ids).to eq([]) + expect(secure_user_channel.user_ids).to eq([]) + expect(secure_group_channel.user_ids).to eq([]) + + post "/presence/update.json", params: { + client_id: client_id, + present_channels: [ + ch1.name, + secure_user_channel.name, + secure_group_channel.name, + allowed_user_channel.name, + allowed_group_channel.name, + "/test/nonexistent" + ] + } + expect(response.status).to eq(200) + expect(response.parsed_body).to eq({ + ch1.name => true, + secure_user_channel.name => false, + secure_group_channel.name => false, + allowed_user_channel.name => true, + allowed_group_channel.name => true, + "/test/nonexistent" => false, + }) + + expect(ch1.user_ids).to eq([user.id]) + expect(secure_user_channel.user_ids).to eq([]) + expect(secure_group_channel.user_ids).to eq([]) + expect(allowed_user_channel.user_ids).to eq([user.id]) + expect(allowed_group_channel.user_ids).to eq([user.id]) + end + end + + describe "#get" do + let(:user2) { Fabricate(:user) } + let(:user3) { Fabricate(:user) } + + it "works" do + get "/presence/get", params: { channel: ch1.name } + expect(response.status).to eq(200) + body = response.parsed_body + expect(body["users"]).to eq([]) + expect(body["count"]).to eq(0) + expect(body["last_message_id"]).to eq(MessageBus.last_id(ch1.message_bus_channel_name)) + + ch1.present(user_id: user.id, client_id: SecureRandom.hex) + ch1.present(user_id: user2.id, client_id: SecureRandom.hex) + ch1.present(user_id: user3.id, client_id: SecureRandom.hex) + + get "/presence/get", params: { channel: ch1.name } + body = response.parsed_body + expect(body["users"].map { |u| u["id"] }).to contain_exactly(user.id, user2.id, user3.id) + expect(body["users"][0].keys).to contain_exactly("avatar_template", "id", "name", "username") 
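+      # count reflects distinct present user_ids; last_message_id is the channel's MessageBus backlog id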
+ expect(body["count"]).to eq(3) + expect(body["last_message_id"]).to eq(MessageBus.last_id(ch1.message_bus_channel_name)) + end + + it "respects the existence/security of the channel" do + sign_in user + + get "/presence/get", params: { channel: ch1.name } + expect(response.status).to eq(200) + + get "/presence/get", params: { channel: secure_user_channel.name } + expect(response.status).to eq(404) + + get "/presence/get", params: { channel: secure_group_channel.name } + expect(response.status).to eq(404) + + get "/presence/get", params: { channel: allowed_user_channel.name } + expect(response.status).to eq(200) + + get "/presence/get", params: { channel: allowed_group_channel.name } + expect(response.status).to eq(200) + + get "/presence/get", params: { channel: "/test/nonexistent" } + expect(response.status).to eq(404) + end + + it "works for count_only channels" do + get "/presence/get", params: { channel: count_only_channel.name } + expect(response.status).to eq(200) + expect(response.parsed_body.keys).to contain_exactly("count", "last_message_id") + expect(response.parsed_body["count"]).to eq(0) + + count_only_channel.present(user_id: user.id, client_id: "a") + + get "/presence/get", params: { channel: count_only_channel.name } + expect(response.status).to eq(200) + expect(response.parsed_body["count"]).to eq(1) + end + + end + +end