From 06b971cbe5c4af8835069090fff92d537cb37577 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 03:50:31 +0000 Subject: [PATCH 1/7] Initial plan From ae281e98e29d212b11ca6c9f25f53983a41a6e78 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 04:11:05 +0000 Subject: [PATCH 2/7] feat: add round-robin session queue scheduling across users - Add SESSION_QUEUE_MODE type and session_queue_mode config field - Modify dequeue() to support round-robin ordering when multiuser mode is active, serving each user in turn based on last-served timestamp - Add tests for FIFO and round-robin dequeue behavior Co-authored-by: lstein <111189+lstein@users.noreply.github.com> --- .../app/services/config/config_default.py | 3 + .../session_queue/session_queue_sqlite.py | 46 +++- .../test_session_queue_dequeue.py | 214 ++++++++++++++++++ 3 files changed, 259 insertions(+), 4 deletions(-) create mode 100644 tests/app/services/session_queue/test_session_queue_dequeue.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 2cc2aaf273c..7f5039474e8 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -29,6 +29,7 @@ ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] +SESSION_QUEUE_MODE = Literal["FIFO", "round_robin"] CONFIG_SCHEMA_VERSION = "4.0.2" @@ -102,6 +103,7 @@ class InvokeAIAppConfig(BaseSettings): pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. clear_queue_on_startup: Empties session queue on startup. + session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting. allow_nodes: List of nodes to allow. Omit to allow all. deny_nodes: List of nodes to deny. Omit to deny none. node_cache_size: How many cached nodes to keep in memory. @@ -191,6 +193,7 @@ class InvokeAIAppConfig(BaseSettings): pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.") + session_queue_mode: SESSION_QUEUE_MODE = Field(default="round_robin", description="Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.") # NODES allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 4f46136fd79..fe7cc138bd1 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -155,9 +155,45 @@ async def enqueue_batch( return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: - with self._db.transaction() as cursor: - cursor.execute( - """--sql + config = self.__invoker.services.configuration + use_round_robin = config.multiuser and config.session_queue_mode == "round_robin" + + if use_round_robin: + query = """--sql + WITH user_last_served AS ( + -- Track when each user last had an item started, to determine whose turn it is. + SELECT user_id, MAX(started_at) AS last_served_at + FROM session_queue + WHERE started_at IS NOT NULL + GROUP BY user_id + ), + user_next_item AS ( + -- For each user, select their single best pending item (highest priority, then oldest). + SELECT + user_id, + item_id, + ROW_NUMBER() OVER ( + PARTITION BY user_id + ORDER BY priority DESC, item_id ASC + ) AS rn + FROM session_queue + WHERE status = 'pending' + ) + SELECT + sq.*, + u.display_name AS user_display_name, + u.email AS user_email + FROM session_queue sq + LEFT JOIN users u ON sq.user_id = u.user_id + JOIN user_next_item uni ON sq.item_id = uni.item_id AND uni.rn = 1 + LEFT JOIN user_last_served uls ON sq.user_id = uls.user_id + ORDER BY + COALESCE(uls.last_served_at, '1970-01-01') ASC, + sq.item_id ASC + LIMIT 1 + """ + else: + query = """--sql SELECT sq.*, u.display_name as user_display_name, @@ -170,7 +206,9 @@ def dequeue(self) -> Optional[SessionQueueItem]: sq.item_id ASC LIMIT 1 """ - ) + + with self._db.transaction() as cursor: + cursor.execute(query) result = cast(Union[sqlite3.Row, None], cursor.fetchone()) if result is None: return None diff --git a/tests/app/services/session_queue/test_session_queue_dequeue.py b/tests/app/services/session_queue/test_session_queue_dequeue.py new file mode 100644 index 00000000000..0f82f2babaa --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_dequeue.py @@ -0,0 +1,214 @@ +"""Tests for session queue dequeue() ordering: FIFO and round-robin modes.""" + +import json +import uuid +from typing import Optional + +import pytest +from pydantic_core import to_jsonable_python + +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState + +_EMPTY_SESSION_JSON = json.dumps(to_jsonable_python(GraphExecutionState(graph=Graph()).model_dump())) + + +@pytest.fixture +def session_queue_fifo(mock_invoker: Invoker) -> SqliteSessionQueue: + """Queue backed by a single-user (FIFO) invoker.""" + # Default config has multiuser=False, so FIFO is always used. + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +@pytest.fixture +def session_queue_round_robin(mock_invoker: Invoker) -> SqliteSessionQueue: + """Queue backed by a multiuser invoker with round_robin mode.""" + mock_invoker.services.configuration = InvokeAIAppConfig( + use_memory_db=True, + node_cache_size=0, + multiuser=True, + session_queue_mode="round_robin", + ) + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item( + session_queue: SqliteSessionQueue, + queue_id: str, + user_id: str, + priority: int = 0, +) -> int: + """Directly insert a minimal queue item and return its item_id.""" + session_id = str(uuid.uuid4()) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id, user_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (queue_id, _EMPTY_SESSION_JSON, session_id, batch_id, None, priority, None, None, None, None, user_id), + ) + return cursor.lastrowid # type: ignore[return-value] + + +def _dequeue_user_ids(session_queue: SqliteSessionQueue, count: int) -> list[Optional[str]]: + """Dequeue `count` items and return the list of user_ids in dequeue order.""" + result = [] + for _ in range(count): + item = session_queue.dequeue() + result.append(item.user_id if item is not None else None) + return result + + +# --------------------------------------------------------------------------- +# FIFO tests +# --------------------------------------------------------------------------- + + +def test_fifo_single_user_order(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: items from a single user are dequeued in insertion order.""" + queue_id = "default" + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_fifo, 3) + assert user_ids == ["user_a", "user_a", "user_a"] + + +def test_fifo_multi_user_preserves_insertion_order(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: jobs from multiple users are dequeued in strict insertion order, not interleaved.""" + queue_id = "default" + # Insert A1, A2, B1, C1, C2, A3 – FIFO should preserve this exact order. + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + _insert_queue_item(session_queue_fifo, queue_id, "user_b") + _insert_queue_item(session_queue_fifo, queue_id, "user_c") + _insert_queue_item(session_queue_fifo, queue_id, "user_c") + _insert_queue_item(session_queue_fifo, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_fifo, 6) + assert user_ids == ["user_a", "user_a", "user_b", "user_c", "user_c", "user_a"] + + +def test_fifo_priority_respected(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: higher-priority items are dequeued before lower-priority ones.""" + queue_id = "default" + _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=0) + _insert_queue_item(session_queue_fifo, queue_id, "user_a", priority=10) + + user_ids = _dequeue_user_ids(session_queue_fifo, 2) + # Both are user_a; second inserted item has higher priority and should come first. + assert user_ids == ["user_a", "user_a"] + + +def test_fifo_returns_none_when_empty(session_queue_fifo: SqliteSessionQueue) -> None: + """FIFO: dequeue returns None when the queue is empty.""" + assert session_queue_fifo.dequeue() is None + + +# --------------------------------------------------------------------------- +# Round-robin tests +# --------------------------------------------------------------------------- + + +def test_round_robin_interleaves_users(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: jobs from multiple users are interleaved one per user per round. + + Queue insertion order (matching the issue example): + A job 1, A job 2, B job 1, C job 1, C job 2, A job 3 + + Expected dequeue order: + A job 1, B job 1, C job 1, A job 2, C job 2, A job 3 + """ + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_b") + _insert_queue_item(session_queue_round_robin, queue_id, "user_c") + _insert_queue_item(session_queue_round_robin, queue_id, "user_c") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 6) + assert user_ids == ["user_a", "user_b", "user_c", "user_a", "user_c", "user_a"] + + +def test_round_robin_single_user_behaves_like_fifo(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin with only one user produces the same order as FIFO.""" + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 3) + assert user_ids == ["user_a", "user_a", "user_a"] + + +def test_round_robin_handles_user_joining_mid_queue(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: a user who joins later is correctly interleaved.""" + queue_id = "default" + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_a") + _insert_queue_item(session_queue_round_robin, queue_id, "user_b") + + user_ids = _dequeue_user_ids(session_queue_round_robin, 3) + # Round 1: A (oldest rank-1 item), B (rank-1 item) + # Round 2: A (rank-2 item) + assert user_ids == ["user_a", "user_b", "user_a"] + + +def test_round_robin_returns_none_when_empty(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: dequeue returns None when the queue is empty.""" + assert session_queue_round_robin.dequeue() is None + + +def test_round_robin_priority_within_user_respected(session_queue_round_robin: SqliteSessionQueue) -> None: + """Round-robin: within a single user's items, higher priority is dequeued first.""" + queue_id = "default" + # Insert low-priority item first, then high-priority for same user. + _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=0) + _insert_queue_item(session_queue_round_robin, queue_id, "user_a", priority=10) + _insert_queue_item(session_queue_round_robin, queue_id, "user_b", priority=0) + + # Round 1: user_a's best item (priority 10), user_b's only item. + # Round 2: user_a's remaining item (priority 0). + items = [] + for _ in range(3): + item = session_queue_round_robin.dequeue() + assert item is not None + items.append((item.user_id, item.priority)) + + assert items[0] == ("user_a", 10) + assert items[1] == ("user_b", 0) + assert items[2] == ("user_a", 0) + + +def test_round_robin_ignored_in_single_user_mode(mock_invoker: Invoker) -> None: + """When multiuser=False, round_robin config is ignored and FIFO is used.""" + mock_invoker.services.configuration = InvokeAIAppConfig( + use_memory_db=True, + node_cache_size=0, + multiuser=False, + session_queue_mode="round_robin", + ) + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + + queue_id = "default" + _insert_queue_item(queue, queue_id, "user_a") + _insert_queue_item(queue, queue_id, "user_a") + _insert_queue_item(queue, queue_id, "user_b") + + # FIFO order: user_a, user_a, user_b + user_ids = _dequeue_user_ids(queue, 3) + assert user_ids == ["user_a", "user_a", "user_b"] From 28c63b15f38f4d90f5c69da5a561553ac939ad60 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 25 Apr 2026 14:33:44 -0400 Subject: [PATCH 3/7] fix(multiuser): restore X/Y queue badge and cross-user queue list Three regressions from the multiuser isolation work in 33ec16de were preventing non-admin users from seeing the broader queue: 1. The "X/Y" pending badge collapsed to a single number because the backend stopped returning per-user counts and the frontend dropped the X/Y formatting. Restored user_pending/user_in_progress on SessionQueueStatus and the X/Y formatter; get_queue_status now takes an explicit is_admin flag for current-item visibility. 2. The queue list only showed the caller's own jobs because get_queue_item_ids filtered by user. Per-item field redaction already happens in list_all_queue_items / get_queue_items_by_item_ids, so the id list itself can be returned unfiltered. 3. After enqueue or status change in another user's batch, A's queue list, badge totals, and item statuses stayed stale until reload because QueueItemStatusChangedEvent and BatchEnqueuedEvent went only to user:{owner} + admin rooms. Now the full event still goes to those rooms, and a sanitized companion (user_id="redacted", identifiers and error fields stripped) is broadcast to the queue room with the owner and admin sids in skip_sid so they don't receive a clobbering duplicate. The frontend handler short-circuits the redacted variant to tag invalidation only, skipping per-session side effects. Co-Authored-By: Claude Opus 4.7 (1M context) --- invokeai/app/api/routers/session_queue.py | 20 +-- invokeai/app/api/sockets.py | 103 ++++++++++++--- .../app/services/config/config_default.py | 2 +- .../session_queue/session_queue_base.py | 15 ++- .../session_queue/session_queue_common.py | 6 + .../session_queue/session_queue_sqlite.py | 47 ++++--- .../queue/components/QueueCountBadge.tsx | 28 +++- .../frontend/web/src/services/api/schema.ts | 26 +++- .../src/services/events/setEventListeners.tsx | 18 +++ .../routers/test_multiuser_authorization.py | 120 +++++++++++++++--- 10 files changed, 317 insertions(+), 68 deletions(-) diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 41a5a411c7a..d62cac5095f 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -141,12 +141,11 @@ async def get_queue_item_ids( queue_id: str = Path(description="The queue id to perform this operation on"), order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), ) -> ItemIdsResult: - """Gets all queue item ids that match the given parameters. Non-admin users only see their own items.""" + """Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive; + per-item field redaction is performed when the items are fetched via list_all_queue_items or + get_queue_items_by_item_ids.""" try: - user_id = None if current_user.is_admin else current_user.user_id - return ApiDependencies.invoker.services.session_queue.get_queue_item_ids( - queue_id=queue_id, order_dir=order_dir, user_id=user_id - ) + return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}") @@ -436,10 +435,15 @@ async def get_queue_status( current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> SessionQueueAndProcessorStatus: - """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.""" + """Gets the status of the session queue. Returns global counts plus the calling user's own + pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the + current item's identifiers unless they own it.""" try: - user_id = None if current_user.is_admin else current_user.user_id - queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id) + queue = ApiDependencies.invoker.services.session_queue.get_queue_status( + queue_id, + user_id=current_user.user_id, + is_admin=current_user.is_admin, + ) processor = ApiDependencies.invoker.services.session_processor.get_status() return SessionQueueAndProcessorStatus(queue=queue, processor=processor) except Exception as e: diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 5783b804c0b..b02b5bbb067 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -260,20 +260,37 @@ async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) + def _owner_and_admin_sids(self, owner_user_id: str) -> list[str]: + """Sids belonging to the event's owner or to any admin. + + Used as `skip_sid` when broadcasting a sanitized companion event to the queue room, + so the owner and admins (who already received the full event) don't get a second + copy that would clobber their cache with redacted values. + """ + return [ + sid + for sid, info in self._socket_users.items() + if info.get("user_id") == owner_user_id or info.get("is_admin") + ] + async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): """Handle queue events with user isolation. - All queue item events (invocation events AND QueueItemStatusChangedEvent) are - private to the owning user and admins. They carry unsanitized user_id, batch_id, - session_id, origin, destination and error metadata, and must never be broadcast - to the whole queue room — otherwise any other authenticated subscriber could - observe cross-user queue activity. + Queue events split into two routing paths: - RecallParametersUpdatedEvent is also private to the owner + admins. + 1. The owner and admins receive the full unsanitized event in their `user:{id}` / + `admin` rooms. The full payload may include batch_id, session_id, origin, + destination, error metadata, etc. - BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and - is also routed privately. QueueClearedEvent is the only queue event that - is still broadcast to the whole queue room. + 2. For events that other authenticated users need to know about so their queue list + and badge counts stay in sync (QueueItemStatusChangedEvent and BatchEnqueuedEvent), + a sanitized companion event is also emitted to the full queue room with the + owner's and admins' sids in `skip_sid`. The companion uses `user_id="redacted"` + as a sentinel so the frontend handler knows to do tag invalidation only and skip + per-session side effects. + + InvocationEventBase events stay private (owner + admins only). RecallParametersUpdatedEvent + is also private. QueueClearedEvent has no user identity and is broadcast to the queue room. IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase inherits from QueueItemEventBase. The order of isinstance checks matters! @@ -302,10 +319,51 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room") - # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized - # user_id, batch_id, session_id, origin, destination and error metadata. - # They are private to the owning user + admins — never broadcast to the - # full queue room. + # QueueItemStatusChangedEvent: full to owner+admin, sanitized to everyone else in + # the queue room so their queue list, badge, and item caches refresh. + elif isinstance(event_data, QueueItemStatusChangedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + + sanitized = event_data.model_copy( + update={ + "user_id": "redacted", + "batch_id": "redacted", + "session_id": "redacted", + "origin": None, + "destination": None, + "error_type": None, + "error_message": None, + "error_traceback": None, + } + ) + # Strip identifying fields out of the embedded batch_status / queue_status too. + sanitized.batch_status = sanitized.batch_status.model_copy( + update={"batch_id": "redacted", "origin": None, "destination": None} + ) + sanitized.queue_status = sanitized.queue_status.model_copy( + update={ + "item_id": None, + "session_id": None, + "batch_id": None, + "user_pending": None, + "user_in_progress": None, + } + ) + await self._sio.emit( + event=event_name, + data=sanitized.model_dump(mode="json"), + room=event_data.queue_id, + skip_sid=self._owner_and_admin_sids(event_data.user_id), + ) + + logger.debug( + f"Emitted queue_item_status_changed: full to {user_room}+admin, sanitized to queue {event_data.queue_id}" + ) + + # Other queue item events (currently none beyond QueueItemStatusChangedEvent that + # carry user_id) stay private to owner + admins. elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"): user_room = f"user:{event_data.user_id}" await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) @@ -320,14 +378,25 @@ async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room") - # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and - # enqueued counts. Route it privately to the owner + admins so other - # users do not observe cross-user batch activity. + # BatchEnqueuedEvent: full to owner+admin, sanitized to everyone else in the queue + # room so their badge total and queue list pick up the new items. elif isinstance(event_data, BatchEnqueuedEvent): user_room = f"user:{event_data.user_id}" await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") - logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room") + + sanitized = event_data.model_copy( + update={"user_id": "redacted", "batch_id": "redacted", "origin": None} + ) + await self._sio.emit( + event=event_name, + data=sanitized.model_dump(mode="json"), + room=event_data.queue_id, + skip_sid=self._owner_and_admin_sids(event_data.user_id), + ) + logger.debug( + f"Emitted batch_enqueued: full to {user_room}+admin, sanitized to queue {event_data.queue_id}" + ) else: # For remaining queue events (e.g. QueueClearedEvent) that do not diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 240371e981b..c99461b3fab 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -109,7 +109,7 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. - session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting. + session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin` clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. allow_nodes: List of nodes to allow. Omit to allow all. diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 14b93d97fc7..04bd81e3174 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -73,8 +73,19 @@ def is_full(self, queue_id: str) -> IsFullResult: pass @abstractmethod - def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: - """Gets the status of the queue. If user_id is provided, also includes user-specific counts.""" + def get_queue_status( + self, + queue_id: str, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> SessionQueueStatus: + """Gets the status of the queue. + + Always returns global pending/in_progress/etc. counts. When user_id is provided, also + populates user_pending and user_in_progress with that user's own counts (so the UI can + render an X/Y badge). When is_admin is False, the current item's identifiers are hidden + unless the calling user owns the in-progress item. + """ pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index d87221fbbae..7472ea07f63 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -309,6 +309,12 @@ class SessionQueueStatus(BaseModel): failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") total: int = Field(..., description="Total number of queue items") + user_pending: Optional[int] = Field( + default=None, description="Number of pending queue items for the calling user (multiuser only)" + ) + user_in_progress: Optional[int] = Field( + default=None, description="Number of in-progress queue items for the calling user (multiuser only)" + ) class SessionQueueCountsByDestination(BaseModel): diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 2e7c9256947..326baed1b31 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -884,9 +884,25 @@ def get_queue_item_ids( return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids)) - def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: + def get_queue_status( + self, + queue_id: str, + user_id: Optional[str] = None, + is_admin: bool = False, + ) -> SessionQueueStatus: with self._db.transaction() as cursor: - # When user_id is provided (non-admin), only count that user's items + cursor.execute( + """--sql + SELECT status, count(*) + FROM session_queue + WHERE queue_id = ? + GROUP BY status + """, + (queue_id,), + ) + counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + + user_counts_result: list[sqlite3.Row] = [] if user_id is not None: cursor.execute( """--sql @@ -897,24 +913,23 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess """, (queue_id, user_id), ) - else: - cursor.execute( - """--sql - SELECT status, count(*) - FROM session_queue - WHERE queue_id = ? - GROUP BY status - """, - (queue_id,), - ) - counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + user_counts_result = cast(list[sqlite3.Row], cursor.fetchall()) current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} - # For non-admin users, hide current item details if they don't own it - show_current_item = current_item is not None and (user_id is None or current_item.user_id == user_id) + user_pending: Optional[int] = None + user_in_progress: Optional[int] = None + if user_id is not None: + user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} + user_pending = user_counts.get("pending", 0) + user_in_progress = user_counts.get("in_progress", 0) + + # Non-admins cannot see the current item's identifiers unless they own it. + show_current_item = current_item is not None and ( + is_admin or user_id is None or current_item.user_id == user_id + ) return SessionQueueStatus( queue_id=queue_id, @@ -927,6 +942,8 @@ def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> Sess failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, + user_pending=user_pending, + user_in_progress=user_in_progress, ) def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: diff --git a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx index e8636466066..1ba2ffd572d 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueCountBadge.tsx @@ -1,4 +1,6 @@ import { Badge, Portal } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectIsAuthenticated } from 'features/auth/store/authSlice'; import type { RefObject } from 'react'; import { memo, useEffect, useMemo, useState } from 'react'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; @@ -10,14 +12,24 @@ type Props = { type SessionQueueStatus = components['schemas']['SessionQueueStatus']; +const hasUserCounts = (queueData: SessionQueueStatus): boolean => { + return ( + queueData.user_pending !== undefined && + queueData.user_pending !== null && + queueData.user_in_progress !== undefined && + queueData.user_in_progress !== null + ); +}; + /** - * Calculates the appropriate badge text based on queue status. + * Calculates the appropriate badge text based on queue status and authentication state. * Returns null if badge should be hidden. * - * In multiuser mode, the backend already scopes counts to the current user for non-admins, - * so pending + in_progress reflects the user's own queue items. + * In multiuser mode, the badge is "X/Y" where X is the calling user's pending+in_progress count + * and Y is the total across all users. In single-user mode (or when user counts are unavailable) + * the badge shows the total only. */ -const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null => { +const getBadgeText = (queueData: SessionQueueStatus | undefined, isAuthenticated: boolean): string | null => { if (!queueData) { return null; } @@ -28,18 +40,24 @@ const getBadgeText = (queueData: SessionQueueStatus | undefined): string | null return null; } + if (isAuthenticated && hasUserCounts(queueData)) { + const userPending = queueData.user_pending! + queueData.user_in_progress!; + return `${userPending}/${totalPending}`; + } + return totalPending.toString(); }; export const QueueCountBadge = memo(({ targetRef }: Props) => { const [badgePos, setBadgePos] = useState<{ x: string; y: string } | null>(null); + const isAuthenticated = useAppSelector(selectIsAuthenticated); const { queueData } = useGetQueueStatusQuery(undefined, { selectFromResult: (res) => ({ queueData: res.data?.queue, }), }); - const badgeText = useMemo(() => getBadgeText(queueData), [queueData]); + const badgeText = useMemo(() => getBadgeText(queueData, isAuthenticated), [queueData, isAuthenticated]); useEffect(() => { if (!targetRef.current) { diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 4b8e4da95a5..f12ec2e538e 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1795,7 +1795,9 @@ export type paths = { }; /** * Get Queue Item Ids - * @description Gets all queue item ids that match the given parameters. Non-admin users only see their own items. + * @description Gets all queue item ids that match the given parameters. The IDs themselves are not sensitive; + * per-item field redaction is performed when the items are fetched via list_all_queue_items or + * get_queue_items_by_item_ids. */ get: operations["get_queue_item_ids"]; put?: never; @@ -2055,7 +2057,9 @@ export type paths = { }; /** * Get Queue Status - * @description Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it. + * @description Gets the status of the session queue. Returns global counts plus the calling user's own + * pending/in_progress counts (so the UI can show an X/Y badge). Non-admin users cannot see the + * current item's identifiers unless they own it. */ get: operations["get_queue_status"]; put?: never; @@ -15641,6 +15645,7 @@ export type components = { * force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). * pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. * max_queue_size: Maximum number of items in the session queue. + * session_queue_mode: Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.
Valid values: `FIFO`, `round_robin` * clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. * max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. * allow_nodes: List of nodes to allow. Omit to allow all. @@ -15972,6 +15977,13 @@ export type components = { * @default 10000 */ max_queue_size?: number; + /** + * Session Queue Mode + * @description Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting. + * @default round_robin + * @enum {string} + */ + session_queue_mode?: "FIFO" | "round_robin"; /** * Clear Queue On Startup * @description Empties session queue on startup. If true, disables `max_queue_history`. @@ -26807,6 +26819,16 @@ export type components = { * @description Total number of queue items */ total: number; + /** + * User Pending + * @description Number of pending queue items for the calling user (multiuser only) + */ + user_pending?: number | null; + /** + * User In Progress + * @description Number of in-progress queue items for the calling user (multiuser only) + */ + user_in_progress?: number | null; }; /** * SetupRequest diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 6771e9e7e00..d742ad09bf5 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -388,6 +388,24 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis }); socket.on('queue_item_status_changed', (data) => { + // Sanitized companion event sent to non-owner queue subscribers in multiuser mode. The + // backend sets user_id="redacted" and clears identifiers/error fields. We must not run + // payload-driven cache mutations or per-session side effects (node state reset, progress + // clear, completion bookkeeping) — those belong to the owner. Just invalidate queue tags + // so the non-owner's queue list and badge counts refetch with sanitized data. + if (data.user_id === 'redacted') { + log.trace({ data }, `Sanitized queue_item_status_changed for item ${data.item_id}`); + const tags: ApiTagDescription[] = [ + 'SessionQueueStatus', + 'SessionQueueItemIdList', + { type: 'SessionQueueItem', id: data.item_id }, + { type: 'SessionQueueItem', id: LIST_TAG }, + { type: 'SessionQueueItem', id: LIST_ALL_TAG }, + ]; + dispatch(queueApi.util.invalidateTags(tags)); + return; + } + if (finishedQueueItemIds.has(data.item_id)) { log.trace({ data }, `Received event for already-finished queue item ${data.item_id}`); return; diff --git a/tests/app/routers/test_multiuser_authorization.py b/tests/app/routers/test_multiuser_authorization.py index 85354c6a577..813b5170a09 100644 --- a/tests/app/routers/test_multiuser_authorization.py +++ b/tests/app/routers/test_multiuser_authorization.py @@ -1332,14 +1332,31 @@ def test_get_queue_status_hides_current_item_for_non_owner(self): assert status_obj.session_id is None assert status_obj.batch_id is None - def test_session_queue_status_no_user_fields(self): - """SessionQueueStatus should not have user_pending/user_in_progress fields anymore. - Non-admin users now get their own counts in the main pending/in_progress fields.""" + def test_session_queue_status_has_user_fields(self): + """SessionQueueStatus exposes user_pending/user_in_progress so the queue badge + can render an X/Y count (X = caller's jobs, Y = global total).""" from invokeai.app.services.session_queue.session_queue_common import SessionQueueStatus fields = set(SessionQueueStatus.model_fields.keys()) - assert "user_pending" not in fields - assert "user_in_progress" not in fields + assert "user_pending" in fields + assert "user_in_progress" in fields + + status_obj = SessionQueueStatus( + queue_id="default", + item_id=None, + session_id=None, + batch_id=None, + pending=5, + in_progress=1, + completed=0, + failed=0, + canceled=0, + total=6, + user_pending=2, + user_in_progress=1, + ) + assert status_obj.user_pending == 2 + assert status_obj.user_in_progress == 1 # =========================================================================== @@ -1707,8 +1724,11 @@ def test_batch_enqueued_event_carries_user_id(self) -> None: assert event.queue_id == "default" def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None: - """Verify that _handle_queue_event emits QueueItemStatusChangedEvent ONLY to - user:{user_id} and admin rooms, never to the queue_id room.""" + """_handle_queue_event must emit the FULL QueueItemStatusChangedEvent only to the + owner's user room and the admin room. A sanitized companion (user_id="redacted", + identifiers stripped) is also emitted to the queue_id room so other users' UIs can + refresh, with the owner's and admins' sids in skip_sid so they don't get a duplicate + that would clobber their cache.""" import asyncio from unittest.mock import AsyncMock @@ -1757,20 +1777,60 @@ def test_queue_item_status_changed_routed_privately(self, socketio: Any) -> None ), ) + # Track owner sid so we can verify skip_sid is honored + socketio._socket_users["sid-owner"] = {"user_id": "owner-xyz", "is_admin": False} + socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True} + socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False} + mock_emit = AsyncMock() socketio._sio.emit = mock_emit asyncio.run(socketio._handle_queue_event(("queue_item_status_changed", event))) - rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list] - assert "user:owner-xyz" in rooms_emitted_to - assert "admin" in rooms_emitted_to - # CRITICAL: must NOT emit to the queue_id room — that would leak to other users - assert "default" not in rooms_emitted_to + # Collect (room, payload, skip_sid) for each emit call + emits = [ + (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list + ] + + # Full event must go to owner room and admin room with original sensitive fields + owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-xyz"] + admin_emits = [(p, s) for r, p, s in emits if r == "admin"] + assert len(owner_emits) == 1 and len(admin_emits) == 1 + for payload, _ in owner_emits + admin_emits: + assert payload["user_id"] == "owner-xyz" + assert payload["batch_id"] == "batch-private" + assert payload["session_id"] == "sess-private" + assert payload["destination"] == "canvas" + + # A sanitized companion event must go to the queue_id room with sensitive fields cleared + queue_emits = [(p, s) for r, p, s in emits if r == "default"] + assert len(queue_emits) == 1, "expected exactly one sanitized emit to queue room" + sanitized_payload, skip_sid = queue_emits[0] + assert sanitized_payload["user_id"] == "redacted" + assert sanitized_payload["batch_id"] == "redacted" + assert sanitized_payload["session_id"] == "redacted" + assert sanitized_payload["origin"] is None + assert sanitized_payload["destination"] is None + assert sanitized_payload["error_type"] is None + assert sanitized_payload["batch_status"]["batch_id"] == "redacted" + assert sanitized_payload["batch_status"]["destination"] is None + assert sanitized_payload["queue_status"]["item_id"] is None + assert sanitized_payload["queue_status"]["batch_id"] is None + assert sanitized_payload["queue_status"]["user_pending"] is None + # Owner and admin sids must be skipped so they don't receive the duplicate + assert "sid-owner" in skip_sid + assert "sid-admin" in skip_sid + # Third-party user must NOT be skipped — they need the sanitized event + assert "sid-other" not in skip_sid + # Status (non-sensitive) is preserved so the non-owner UI knows what changed + assert sanitized_payload["status"] == "in_progress" + assert sanitized_payload["item_id"] == 1 def test_batch_enqueued_routed_privately(self, socketio: Any) -> None: - """Verify that _handle_queue_event emits BatchEnqueuedEvent ONLY to - user:{user_id} and admin rooms, never to the queue_id room.""" + """_handle_queue_event must emit the FULL BatchEnqueuedEvent only to the owner's + user room and the admin room. A sanitized companion (user_id="redacted", batch_id + and origin stripped) is also emitted to the queue_id room so other users' badge + totals refresh, with owner/admin sids in skip_sid.""" import asyncio from unittest.mock import AsyncMock @@ -1791,15 +1851,39 @@ def test_batch_enqueued_routed_privately(self, socketio: Any) -> None: ) event = BatchEnqueuedEvent.build(enqueue_result, user_id="owner-zzz") + socketio._socket_users["sid-owner"] = {"user_id": "owner-zzz", "is_admin": False} + socketio._socket_users["sid-admin"] = {"user_id": "admin-1", "is_admin": True} + socketio._socket_users["sid-other"] = {"user_id": "other-user", "is_admin": False} + mock_emit = AsyncMock() socketio._sio.emit = mock_emit asyncio.run(socketio._handle_queue_event(("batch_enqueued", event))) - rooms_emitted_to = [call.kwargs.get("room") for call in mock_emit.call_args_list] - assert "user:owner-zzz" in rooms_emitted_to - assert "admin" in rooms_emitted_to - assert "default" not in rooms_emitted_to + emits = [ + (c.kwargs.get("room"), c.kwargs.get("data"), c.kwargs.get("skip_sid")) for c in mock_emit.call_args_list + ] + + # Full event to owner + admin contains the real batch_id and origin + owner_emits = [(p, s) for r, p, s in emits if r == "user:owner-zzz"] + admin_emits = [(p, s) for r, p, s in emits if r == "admin"] + assert len(owner_emits) == 1 and len(admin_emits) == 1 + for payload, _ in owner_emits + admin_emits: + assert payload["user_id"] == "owner-zzz" + assert payload["batch_id"] == "batch-pvt" + assert payload["origin"] == "workflows" + + # Sanitized event to queue room: user/batch/origin redacted, owner+admin skipped + queue_emits = [(p, s) for r, p, s in emits if r == "default"] + assert len(queue_emits) == 1 + sanitized_payload, skip_sid = queue_emits[0] + assert sanitized_payload["user_id"] == "redacted" + assert sanitized_payload["batch_id"] == "redacted" + assert sanitized_payload["origin"] is None + assert sanitized_payload["enqueued"] == 5 # count is non-sensitive + assert "sid-owner" in skip_sid + assert "sid-admin" in skip_sid + assert "sid-other" not in skip_sid def test_queue_cleared_still_broadcast(self, socketio: Any) -> None: """QueueClearedEvent does not carry user identity and should still be broadcast From 8179b9de63ed14301e4072baaf13b7aa0e227526 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 25 Apr 2026 14:47:34 -0400 Subject: [PATCH 4/7] docs: regenerate settings.json for session_queue_mode Run via `pnpm run generate-docs-data`. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/generated/settings.json | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index 32140da667a..f0cc5c8961e 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -574,6 +574,20 @@ "type": "", "validation": {} }, + { + "category": "GENERATION", + "default": "round_robin", + "description": "Session queue mode. Use 'FIFO' for traditional first-in-first-out, or 'round_robin' to serve each user's jobs in turn. In single-user mode, FIFO is always used regardless of this setting.", + "env_var": "INVOKEAI_SESSION_QUEUE_MODE", + "literal_values": [ + "FIFO", + "round_robin" + ], + "name": "session_queue_mode", + "required": false, + "type": "typing.Literal['FIFO', 'round_robin']", + "validation": {} + }, { "category": "GENERATION", "default": false, From aaa379b1c2a9e185bd78edb120bb05eefb91de7e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 9 May 2026 11:18:59 -0400 Subject: [PATCH 5/7] fix(session_queue): restore user_pending/user_in_progress computation lost in merge The merge of main into this branch combined two conflicting refactors of get_queue_status: the branch added per-user user_pending/user_in_progress fields while main introduced acting_user_id for redaction. The merge kept the new structure plus the references in the return statement, but lost the lines that compute those variables, leaving user_counts_result populated but unused and raising NameError on every dequeue. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../app/services/session_queue/session_queue_sqlite.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 094272a213c..c1bb71e0b74 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -926,6 +926,13 @@ def get_queue_status( total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} + user_pending: Optional[int] = None + user_in_progress: Optional[int] = None + if user_id is not None: + user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} + user_pending = user_counts.get("pending", 0) + user_in_progress = user_counts.get("in_progress", 0) + # Redaction is decided from the same current_item snapshot used to embed identifiers, # so a concurrent transition (e.g. B finishing while A's status changes) cannot leave # stale identifiers in the result. user_id (count filter) and acting_user_id From 932eeedc33421b6f21ae5744ace83974bdbf885f Mon Sep 17 00:00:00 2001 From: Valeri Che <38873282+DustyShoe@users.noreply.github.com> Date: Thu, 14 May 2026 15:41:29 +0300 Subject: [PATCH 6/7] Feat(canvas): Replace Rectangle tool with new Shapes tool (#9082) * Feat(Canvas): Replace Rectangle tool with multifunctional Shapes tool. * Fix: Tweaked icon size on top bar * Fix: also tweaked SVGs for Gradient tool to align with unified 16px icon size * Feat: added freehand shape * fix(canvas): remove duplicate lasso payload export after rebase * `fix(canvas): clear polygon preview stroke on commit` * chore: remove temporary codex artifact * chore: format with prettier * fix(canvas): preserve shapes sessions across view switch * chore: format with prettier * add: constrain rectangles to squares with shift * fix(canvas): refine shapes space and alt interactions * fix(canvas): preserve polygon sessions across temporary tool switches * refactor(i18n): reuse lasso labels for shapes polygon modes * fix(i18n): merge shapes locale additions with modifier hints * feat(canvas): add shape-specific modifier hints and docs * fix(canvas): refine shape modifier hints and toolbar overflow * fix(canvas): keep toolbar overflow clipped to the right * fix: Escape while panning while drawing should exit shape tool --------- Co-authored-by: Alexander Eichhorn Co-authored-by: dunkeroni --- .../src/content/docs/features/shapes-tool.mdx | 96 ++ invokeai/frontend/web/public/locales/en.json | 12 +- .../components/Tool/GradientIcons.tsx | 16 +- .../components/Tool/ToolChooser.tsx | 4 +- .../components/Tool/ToolShapeTypeToggle.tsx | 65 ++ ...oolRectButton.tsx => ToolShapesButton.tsx} | 21 +- .../components/Toolbar/CanvasToolbar.tsx | 58 +- .../CanvasEntityBufferObjectRenderer.ts | 60 +- .../CanvasEntityObjectRenderer.ts | 38 +- .../konva/CanvasObject/CanvasObjectOval.ts | 88 ++ .../konva/CanvasObject/CanvasObjectPolygon.ts | 113 +++ .../konva/CanvasObject/CanvasObjectRect.ts | 6 +- .../controlLayers/konva/CanvasObject/types.ts | 8 + .../konva/CanvasStateApiModule.ts | 10 +- .../konva/CanvasTool/CanvasRectToolModule.ts | 102 -- .../konva/CanvasTool/CanvasShapeToolModule.ts | 895 ++++++++++++++++++ .../konva/CanvasTool/CanvasToolModule.ts | 186 +++- .../konva/CanvasTool/toolHotkeys.test.ts | 70 ++ .../konva/CanvasTool/toolHotkeys.ts | 65 ++ .../store/canvasSettingsSlice.ts | 11 + .../controlLayers/store/canvasSlice.ts | 12 +- .../src/features/controlLayers/store/types.ts | 27 +- .../layouts/DockviewCanvasHeaderActions.tsx | 18 +- .../layouts/canvasToolModifierHints.test.ts | 152 +-- .../ui/layouts/canvasToolModifierHints.ts | 45 +- 25 files changed, 1897 insertions(+), 281 deletions(-) create mode 100644 docs/src/content/docs/features/shapes-tool.mdx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx rename invokeai/frontend/web/src/features/controlLayers/components/Tool/{ToolRectButton.tsx => ToolShapesButton.tsx} (58%) create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts delete mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts diff --git a/docs/src/content/docs/features/shapes-tool.mdx b/docs/src/content/docs/features/shapes-tool.mdx new file mode 100644 index 00000000000..6ad795aed7a --- /dev/null +++ b/docs/src/content/docs/features/shapes-tool.mdx @@ -0,0 +1,96 @@ +--- +title: Shapes Tool +description: Learn how to draw filled shapes on raster and inpaint mask layers with the Shapes tool. +lastUpdated: 2026-05-11 +--- + +import { Card, CardGrid } from '@astrojs/starlight/components'; + +The Shapes tool is a general-purpose filled-shape drawing tool for the canvas. It replaces the old Rectangle tool and +adds four shape modes under a single toolbar button: + +- **Rect** +- **Oval** +- **Polygon** +- **Freehand** + +You can activate the Shapes tool from the canvas toolbar or with the default hotkey U. + +## Where Shapes Draws + +Shapes always draws into the **active raster target**: + +- On a regular raster layer, Shapes adds filled pixels to that layer. +- On an active inpaint mask layer, Shapes draws directly into the mask. + +:::note +Shapes overlaps with some Lasso workflows on mask layers, but the tools are not identical. Lasso is still the more +specialized masking tool and can create a new mask layer automatically when one does not already exist. +::: + +## Common Behavior + +- Shapes preview live while you draw. +- The fill color uses the current active color. +- The active color's alpha is respected when adding pixels. +- Hold Ctrl on Windows/Linux or Cmd on macOS to switch to **subtractive** mode and cut pixels + out of the active layer. +- In subtractive mode, alpha is ignored and the shape fully clears pixels. +- Press Esc to cancel the current shape session. + +:::tip +When subtractive mode is active, the canvas cursor shows a small minus badge so you can tell at a glance that the next +shape will erase instead of fill. +::: + +## Shape Modes + + + + Drag to draw a rectangle. Hold Shift to constrain to a square. Hold Alt to draw from the + center instead of from a corner. + + + Drag to draw an ellipse. Hold Shift to constrain to a perfect circle. Hold Alt to draw from + the center. + + + Click to place vertices. Click the first point to close and commit the shape. Hold Shift to snap the + pending edge to horizontal, vertical, and 45 degree angles. + + + Click and drag to sketch a filled freehand contour. Release the pointer to commit the shape. + + + +## Moving and Panning During Drawing + +The Shapes tool supports different Space behavior depending on the current mode: + +- **Rect / Oval:** While the pointer is still down, hold Space to move the uncommitted shape instead of + resizing it. Release Space to continue resizing. +- **Polygon / Freehand:** Hold Space during an active session to pan the viewport without discarding the + unfinished shape. + +This is especially useful when drawing large shapes that extend beyond the current viewport. + +## Color Picking While Using Shapes + +The Alt key behaves differently depending on the active Shapes mode: + +- **Rect / Oval:** Before you start dragging, Alt can be used for the temporary color-picker quick-switch. + Once a drag is active, Alt is reserved for drawing from the center. +- **Polygon:** Alt remains available for the temporary color-picker quick-switch between vertex placements. +- **Freehand:** Alt is available before the stroke starts, but not during an active stroke. + +## Practical Examples + +- Use **Rect** or **Oval** to block in clean mask regions quickly. +- Use **Polygon** when you need straight edges and deliberate corner placement. +- Use **Freehand** for irregular organic regions. +- Use **subtractive mode** to cut holes back out of an existing raster or mask layer. + +## Summary + +The Shapes tool is the fastest way to add filled geometric or freeform regions to canvas layers. Use it for structured +fills, mask authoring, and precise subtractive edits without switching away from the current raster target. diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index d99bb04a631..c164d1dafe1 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -728,8 +728,8 @@ "desc": "Select the move tool." }, "selectRectTool": { - "title": "Rect Tool", - "desc": "Select the rect tool." + "title": "Shapes Tool", + "desc": "Select the shapes tool." }, "selectLassoTool": { "title": "Lasso Tool", @@ -2881,6 +2881,10 @@ "polygon": "Polygon", "polygonHint": "Click to add points, click the first point to close." }, + "shape": { + "rect": "Rect", + "oval": "Oval" + }, "modifierHints": { "keys": { "control": "Ctrl", @@ -2896,11 +2900,12 @@ }, "labels": { "pan": "Pan", + "moveShape": "Move shape", "pickColor": "Pick color", "straightLine": "Straight line", "resizeBrush": "Resize brush", "resizeEraser": "Resize eraser", - "subtractMask": "Subtract mask", + "erase": "Erase", "snap45Degrees": "Snap to 45deg", "lockAspectRatio": "Lock ratio", "unlockAspectRatio": "Unlock ratio", @@ -2917,6 +2922,7 @@ "tool": { "brush": "Brush", "eraser": "Eraser", + "shapes": "Shapes", "rectangle": "Rectangle", "lasso": "Lasso", "gradient": "Gradient", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx index b09e46d7320..61074015e76 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/GradientIcons.tsx @@ -32,7 +32,7 @@ export const GradientLinearIcon = memo(() => { const id = useId(); const gradientId = `${id}-gradient-linear-diagonal`; return ( - + @@ -40,15 +40,15 @@ export const GradientLinearIcon = memo(() => { ); @@ -59,7 +59,7 @@ export const GradientRadialIcon = memo(() => { const id = useId(); const gradientId = `${id}-gradient-radial`; return ( - + @@ -67,13 +67,13 @@ export const GradientRadialIcon = memo(() => { ); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx index 30d82722072..c0291f8e587 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx @@ -5,7 +5,7 @@ import { ToolColorPickerButton } from 'features/controlLayers/components/Tool/To import { ToolGradientButton } from 'features/controlLayers/components/Tool/ToolGradientButton'; import { ToolLassoButton } from 'features/controlLayers/components/Tool/ToolLassoButton'; import { ToolMoveButton } from 'features/controlLayers/components/Tool/ToolMoveButton'; -import { ToolRectButton } from 'features/controlLayers/components/Tool/ToolRectButton'; +import { ToolShapesButton } from 'features/controlLayers/components/Tool/ToolShapesButton'; import { ToolTextButton } from 'features/controlLayers/components/Tool/ToolTextButton'; import React from 'react'; @@ -18,7 +18,7 @@ export const ToolChooser: React.FC = () => { - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx new file mode 100644 index 00000000000..2e4530e2e27 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapeTypeToggle.tsx @@ -0,0 +1,65 @@ +import { ButtonGroup, IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectShapeType, settingsShapeTypeChanged } from 'features/controlLayers/store/canvasSettingsSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiCircleBold, PiPolygonBold, PiRectangleBold, PiScribbleLoopBold } from 'react-icons/pi'; + +export const ToolShapeTypeToggle = memo(() => { + const { t } = useTranslation(); + const shapeType = useAppSelector(selectShapeType); + const dispatch = useAppDispatch(); + + const onRectClick = useCallback(() => dispatch(settingsShapeTypeChanged('rect')), [dispatch]); + const onOvalClick = useCallback(() => dispatch(settingsShapeTypeChanged('oval')), [dispatch]); + const onPolygonClick = useCallback(() => dispatch(settingsShapeTypeChanged('polygon')), [dispatch]); + const onFreehandClick = useCallback(() => dispatch(settingsShapeTypeChanged('freehand')), [dispatch]); + + const rectLabel = t('controlLayers.shape.rect', { defaultValue: 'Rect' }); + const ovalLabel = t('controlLayers.shape.oval', { defaultValue: 'Oval' }); + const polygonLabel = t('controlLayers.lasso.polygon', { defaultValue: 'Polygon' }); + const freehandLabel = t('controlLayers.lasso.freehand', { defaultValue: 'Freehand' }); + + return ( + + + } + colorScheme={shapeType === 'rect' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onRectClick} + /> + + + } + colorScheme={shapeType === 'oval' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onOvalClick} + /> + + + } + colorScheme={shapeType === 'polygon' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onPolygonClick} + /> + + + } + colorScheme={shapeType === 'freehand' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={onFreehandClick} + /> + + + ); +}); + +ToolShapeTypeToggle.displayName = 'ToolShapeTypeToggle'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx similarity index 58% rename from invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx rename to invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx index 93029390883..3f6c546d2cf 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolRectButton.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolShapesButton.tsx @@ -3,32 +3,33 @@ import { useSelectTool, useToolIsSelected } from 'features/controlLayers/compone import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiRectangleBold } from 'react-icons/pi'; +import { PiShapesBold } from 'react-icons/pi'; -export const ToolRectButton = memo(() => { +export const ToolShapesButton = memo(() => { const { t } = useTranslation(); const isSelected = useToolIsSelected('rect'); - const selectRect = useSelectTool('rect'); + const selectShapes = useSelectTool('rect'); + const label = t('controlLayers.tool.shapes', { defaultValue: 'Shapes' }); useRegisteredHotkeys({ id: 'selectRectTool', category: 'canvas', - callback: selectRect, + callback: selectShapes, options: { enabled: !isSelected }, - dependencies: [isSelected, selectRect], + dependencies: [isSelected, selectShapes], }); return ( - + } + aria-label={`${label} (U)`} + icon={} colorScheme={isSelected ? 'invokeBlue' : 'base'} variant="solid" - onClick={selectRect} + onClick={selectShapes} /> ); }); -ToolRectButton.displayName = 'ToolRectButton'; +ToolShapesButton.displayName = 'ToolShapesButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx index bee8f5d1a34..bd72306e2e3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx @@ -7,6 +7,7 @@ import { ToolGradientClipToggle } from 'features/controlLayers/components/Tool/T import { ToolGradientModeToggle } from 'features/controlLayers/components/Tool/ToolGradientModeToggle'; import { ToolLassoModeToggle } from 'features/controlLayers/components/Tool/ToolLassoModeToggle'; import { ToolOptionsRowContainer } from 'features/controlLayers/components/Tool/ToolOptionsRowContainer'; +import { ToolShapeTypeToggle } from 'features/controlLayers/components/Tool/ToolShapeTypeToggle'; import { ToolWidthPicker } from 'features/controlLayers/components/Tool/ToolWidthPicker'; import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton'; import { CanvasToolbarFitBboxToMasksButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToMasksButton'; @@ -35,6 +36,7 @@ import { memo, useMemo } from 'react'; export const CanvasToolbar = memo(() => { const isBrushSelected = useToolIsSelected('brush'); const isEraserSelected = useToolIsSelected('eraser'); + const isShapeSelected = useToolIsSelected('rect'); const isTextSelected = useToolIsSelected('text'); const isLassoSelected = useToolIsSelected('lasso'); const isGradientSelected = useToolIsSelected('gradient'); @@ -56,9 +58,28 @@ export const CanvasToolbar = memo(() => { useCanvasToggleBboxHotkey(); return ( - - + + + {isShapeSelected && ( + + + + )} {isGradientSelected && ( @@ -72,21 +93,24 @@ export const CanvasToolbar = memo(() => { )} {isTextSelected ? : showToolWithPicker && } - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + ); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts index 9941761a2ee..b282580fec9 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts @@ -9,8 +9,11 @@ import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; +import { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval'; +import { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; +import { shouldPreserveSuspendableShapesSession } from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import Konva from 'konva'; import type { Logger } from 'roarr'; @@ -83,6 +86,21 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.subscriptions.add( this.manager.tool.$tool.listen(() => { if (this.hasBuffer() && !this.manager.$isBusy.get()) { + const isTemporaryShapesToolSwitch = shouldPreserveSuspendableShapesSession( + this.manager.tool.$tool.get(), + this.manager.tool.$toolBuffer.get(), + this.manager.tool.tools.rect.hasSuspendableSession() + ); + + if (isTemporaryShapesToolSwitch) { + return; + } + + if (this.state?.type === 'polygon' && this.state.previewPoint) { + this.clearBuffer(); + return; + } + this.commitBuffer(); } }) @@ -153,6 +171,24 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.konva.group.add(this.renderer.konva.group); } + didRender = this.renderer.update(this.state, true); + } else if (this.state.type === 'oval') { + assert(this.renderer instanceof CanvasObjectOval || !this.renderer); + + if (!this.renderer) { + this.renderer = new CanvasObjectOval(this.state, this); + this.konva.group.add(this.renderer.konva.group); + } + + didRender = this.renderer.update(this.state, true); + } else if (this.state.type === 'polygon') { + assert(this.renderer instanceof CanvasObjectPolygon || !this.renderer); + + if (!this.renderer) { + this.renderer = new CanvasObjectPolygon(this.state, this); + this.konva.group.add(this.renderer.konva.group); + } + didRender = this.renderer.update(this.state, true); } else if (this.state.type === 'lasso') { assert(this.renderer instanceof CanvasObjectLasso || !this.renderer); @@ -240,28 +276,40 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.log.trace({ buffer: this.renderer.repr() }, 'Committing buffer'); + let committedState = this.state; + + // Polygon previews render an outline while they are still live in the buffer. + // Clear that preview state before adopting the renderer into the persistent object group. + if (committedState.type === 'polygon' && this.renderer instanceof CanvasObjectPolygon) { + committedState = { ...committedState, previewPoint: undefined }; + this.state = null; + this.renderer.update(committedState, true); + } + // Move the buffer to the persistent objects group/renderers this.parent.renderer.adoptObjectRenderer(this.renderer); if (pushToState) { const entityIdentifier = this.parent.entityIdentifier; - switch (this.state.type) { + switch (committedState.type) { case 'brush_line': case 'brush_line_with_pressure': - this.manager.stateApi.addBrushLine({ entityIdentifier, brushLine: this.state }); + this.manager.stateApi.addBrushLine({ entityIdentifier, brushLine: committedState }); break; case 'eraser_line': case 'eraser_line_with_pressure': - this.manager.stateApi.addEraserLine({ entityIdentifier, eraserLine: this.state }); + this.manager.stateApi.addEraserLine({ entityIdentifier, eraserLine: committedState }); break; case 'rect': - this.manager.stateApi.addRect({ entityIdentifier, rect: this.state }); + case 'oval': + case 'polygon': + this.manager.stateApi.addShape({ entityIdentifier, shape: committedState }); break; case 'lasso': - this.manager.stateApi.addLasso({ entityIdentifier, lasso: this.state }); + this.manager.stateApi.addLasso({ entityIdentifier, lasso: committedState }); break; case 'gradient': - this.manager.stateApi.addGradient({ entityIdentifier, gradient: this.state }); + this.manager.stateApi.addGradient({ entityIdentifier, gradient: committedState }); break; } } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts index 903ccaa772c..f62ce3f9822 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts @@ -11,6 +11,8 @@ import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; +import { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval'; +import { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters'; @@ -398,6 +400,26 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { this.konva.objectGroup.add(renderer.konva.group); } + didRender = renderer.update(objectState, force || isFirstRender); + } else if (objectState.type === 'oval') { + assert(renderer instanceof CanvasObjectOval || !renderer); + + if (!renderer) { + renderer = new CanvasObjectOval(objectState, this); + this.renderers.set(renderer.id, renderer); + this.konva.objectGroup.add(renderer.konva.group); + } + + didRender = renderer.update(objectState, force || isFirstRender); + } else if (objectState.type === 'polygon') { + assert(renderer instanceof CanvasObjectPolygon || !renderer); + + if (!renderer) { + renderer = new CanvasObjectPolygon(objectState, this); + this.renderers.set(renderer.id, renderer); + this.konva.objectGroup.add(renderer.konva.group); + } + didRender = renderer.update(objectState, force || isFirstRender); } else if (objectState.type === 'lasso') { assert(renderer instanceof CanvasObjectLasso || !renderer); @@ -455,10 +477,24 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { renderer instanceof CanvasObjectEraserLine || renderer instanceof CanvasObjectEraserLineWithPressure; const isSubtractingLasso = renderer instanceof CanvasObjectLasso && renderer.state.compositeOperation === 'destination-out'; + const isSubtractRect = + renderer instanceof CanvasObjectRect && renderer.state.compositeOperation === 'destination-out'; + const isSubtractOval = + renderer instanceof CanvasObjectOval && renderer.state.compositeOperation === 'destination-out'; + const isSubtractPolygon = + renderer instanceof CanvasObjectPolygon && renderer.state.compositeOperation === 'destination-out'; const isImage = renderer instanceof CanvasObjectImage; const imageIgnoresTransparency = isImage && renderer.state.usePixelBbox === false; const hasClip = renderer instanceof CanvasObjectBrushLine && renderer.state.clip; - if (isEraserLine || isSubtractingLasso || hasClip || (isImage && !imageIgnoresTransparency)) { + if ( + isEraserLine || + isSubtractingLasso || + isSubtractRect || + isSubtractOval || + isSubtractPolygon || + hasClip || + (isImage && !imageIgnoresTransparency) + ) { needsPixelBbox = true; break; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts new file mode 100644 index 00000000000..8c06268f768 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectOval.ts @@ -0,0 +1,88 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import { deepClone } from 'common/util/deepClone'; +import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer'; +import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasOvalState } from 'features/controlLayers/store/types'; +import Konva from 'konva'; +import type { Logger } from 'roarr'; + +export class CanvasObjectOval extends CanvasModuleBase { + readonly type = 'object_oval'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer; + readonly manager: CanvasManager; + readonly log: Logger; + + state: CanvasOvalState; + konva: { + group: Konva.Group; + ellipse: Konva.Ellipse; + }; + + constructor(state: CanvasOvalState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) { + super(); + this.id = state.id; + this.parent = parent; + this.manager = parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug({ state }, 'Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + ellipse: new Konva.Ellipse({ + name: `${this.type}:ellipse`, + listening: false, + radiusX: 0, + radiusY: 0, + perfectDrawEnabled: false, + }), + }; + this.konva.group.add(this.konva.ellipse); + this.state = state; + } + + update(state: CanvasOvalState, force = false): boolean { + if (force || this.state !== state) { + this.log.trace({ state }, 'Updating oval'); + const { rect, color, compositeOperation } = state; + const fill = compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(color); + this.konva.ellipse.setAttrs({ + x: rect.x + rect.width / 2, + y: rect.y + rect.height / 2, + radiusX: rect.width / 2, + radiusY: rect.height / 2, + fill, + globalCompositeOperation: compositeOperation, + }); + this.state = state; + return true; + } + + return false; + } + + setVisibility(isVisible: boolean): void { + this.log.trace({ isVisible }, 'Setting oval visibility'); + this.konva.group.visible(isVisible); + } + + destroy = () => { + this.log.debug('Destroying module'); + this.konva.group.destroy(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + parent: this.parent.id, + state: deepClone(this.state), + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts new file mode 100644 index 00000000000..dc54811569b --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectPolygon.ts @@ -0,0 +1,113 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import { deepClone } from 'common/util/deepClone'; +import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer'; +import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasPolygonState, RgbaColor } from 'features/controlLayers/store/types'; +import Konva from 'konva'; +import type { Logger } from 'roarr'; + +const getPreviewStrokeColor = (color: RgbaColor) => rgbaColorToString({ ...color, a: Math.max(color.a, 0.9) }); + +export class CanvasObjectPolygon extends CanvasModuleBase { + readonly type = 'object_polygon'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer; + readonly manager: CanvasManager; + readonly log: Logger; + + state: CanvasPolygonState; + konva: { + group: Konva.Group; + fillPolygon: Konva.Line; + previewStroke: Konva.Line; + }; + + constructor(state: CanvasPolygonState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) { + super(); + this.id = state.id; + this.parent = parent; + this.manager = parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug({ state }, 'Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + fillPolygon: new Konva.Line({ + name: `${this.type}:fill_polygon`, + listening: false, + closed: true, + strokeEnabled: false, + perfectDrawEnabled: false, + }), + previewStroke: new Konva.Line({ + name: `${this.type}:preview_stroke`, + listening: false, + closed: false, + fillEnabled: false, + lineCap: 'round', + lineJoin: 'round', + perfectDrawEnabled: false, + strokeWidth: 1, + }), + }; + this.konva.group.add(this.konva.fillPolygon, this.konva.previewStroke); + this.state = state; + } + + update(state: CanvasPolygonState, force = false): boolean { + if (force || this.state !== state) { + this.log.trace({ state }, 'Updating polygon'); + const combinedPoints = state.previewPoint + ? [...state.points, state.previewPoint.x, state.previewPoint.y] + : state.points; + const hasRenderablePolygon = combinedPoints.length >= 6; + const isLiveBufferPreview = this.parent.type === 'buffer_renderer' && this.parent.state?.id === state.id; + const fill = + state.compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(state.color); + + this.konva.fillPolygon.setAttrs({ + points: combinedPoints, + visible: hasRenderablePolygon, + fill, + globalCompositeOperation: state.compositeOperation, + }); + + this.konva.previewStroke.setAttrs({ + points: combinedPoints, + visible: (Boolean(state.previewPoint) || isLiveBufferPreview) && combinedPoints.length >= 4, + stroke: getPreviewStrokeColor(state.color), + globalCompositeOperation: 'source-over', + }); + + this.state = state; + return true; + } + + return false; + } + + setVisibility(isVisible: boolean): void { + this.log.trace({ isVisible }, 'Setting polygon visibility'); + this.konva.group.visible(isVisible); + } + + destroy = () => { + this.log.debug('Destroying module'); + this.konva.group.destroy(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + parent: this.parent.id, + state: deepClone(this.state), + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts index 1ac8e5b5f37..e879dcd35ab 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectRect.ts @@ -46,13 +46,15 @@ export class CanvasObjectRect extends CanvasModuleBase { this.isFirstRender = false; this.log.trace({ state }, 'Updating rect'); - const { rect, color } = state; + const { rect, color, compositeOperation } = state; + const fill = compositeOperation === 'destination-out' ? 'rgba(255,255,255,1)' : rgbaColorToString(color); this.konva.rect.setAttrs({ x: rect.x, y: rect.y, width: rect.width, height: rect.height, - fill: rgbaColorToString(color), + fill, + globalCompositeOperation: compositeOperation, }); this.state = state; return true; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts index f193c0b391e..620842a9426 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts @@ -5,6 +5,8 @@ import type { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/ import type { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import type { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; import type { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; +import type { CanvasObjectOval } from 'features/controlLayers/konva/CanvasObject/CanvasObjectOval'; +import type { CanvasObjectPolygon } from 'features/controlLayers/konva/CanvasObject/CanvasObjectPolygon'; import type { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { CanvasBrushLineState, @@ -14,6 +16,8 @@ import type { CanvasGradientState, CanvasImageState, CanvasLassoState, + CanvasOvalState, + CanvasPolygonState, CanvasRectState, } from 'features/controlLayers/store/types'; @@ -28,6 +32,8 @@ export type AnyObjectRenderer = | CanvasObjectEraserLineWithPressure | CanvasObjectRect | CanvasObjectLasso + | CanvasObjectOval + | CanvasObjectPolygon | CanvasObjectImage | CanvasObjectGradient; /** @@ -41,4 +47,6 @@ export type AnyObjectState = | CanvasImageState | CanvasRectState | CanvasLassoState + | CanvasOvalState + | CanvasPolygonState | CanvasGradientState; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index 7d4c76b0c06..26abd908e51 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -25,8 +25,8 @@ import { entityMovedBy, entityMovedTo, entityRasterized, - entityRectAdded, entityReset, + entityShapeAdded, inpaintMaskAdded, rasterLayerAdded, rgAdded, @@ -48,7 +48,7 @@ import type { EntityMovedByPayload, EntityMovedToPayload, EntityRasterizedPayload, - EntityRectAddedPayload, + EntityShapeAddedPayload, Rect, RgbaColor, } from 'features/controlLayers/store/types'; @@ -171,10 +171,10 @@ export class CanvasStateApiModule extends CanvasModuleBase { }; /** - * Adds a rectangle to an entity, pushing state to redux. + * Adds a shape to an entity, pushing state to redux. */ - addRect = (arg: EntityRectAddedPayload) => { - this.store.dispatch(entityRectAdded(arg)); + addShape = (arg: EntityShapeAddedPayload) => { + this.store.dispatch(entityShapeAdded(arg)); }; /** diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts deleted file mode 100644 index 3f64b0c2fc1..00000000000 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasRectToolModule.ts +++ /dev/null @@ -1,102 +0,0 @@ -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; -import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; -import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule'; -import { floorCoord, getPrefixedId, offsetCoord } from 'features/controlLayers/konva/util'; -import type { KonvaEventObject } from 'konva/lib/Node'; -import type { Logger } from 'roarr'; - -export class CanvasRectToolModule extends CanvasModuleBase { - readonly type = 'rect_tool'; - readonly id: string; - readonly path: string[]; - readonly parent: CanvasToolModule; - readonly manager: CanvasManager; - readonly log: Logger; - - constructor(parent: CanvasToolModule) { - super(); - this.id = getPrefixedId(this.type); - this.parent = parent; - this.manager = this.parent.manager; - this.path = this.manager.buildPath(this); - this.log = this.manager.buildLogger(this); - - this.log.debug('Creating module'); - } - - syncCursorStyle = () => { - this.manager.stage.setCursor('crosshair'); - }; - - onStagePointerDown = async (_e: KonvaEventObject) => { - const cursorPos = this.parent.$cursorPos.get(); - const isPrimaryPointerDown = this.parent.$isPrimaryPointerDown.get(); - const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - - if (!cursorPos || !isPrimaryPointerDown || !selectedEntity) { - /** - * Can't do anything without: - * - A cursor position: the cursor is not on the stage - * - The mouse is down: the user is not drawing - * - A selected entity: there is no entity to draw on - */ - return; - } - - const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position); - - await selectedEntity.bufferRenderer.setBuffer({ - id: getPrefixedId('rect'), - type: 'rect', - rect: { x: Math.round(normalizedPoint.x), y: Math.round(normalizedPoint.y), width: 0, height: 0 }, - color: this.manager.stateApi.getCurrentColor(), - }); - }; - - onStagePointerUp = (_e: KonvaEventObject) => { - const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - if (!selectedEntity) { - return; - } - - if (selectedEntity.bufferRenderer.state?.type === 'rect' && selectedEntity.bufferRenderer.hasBuffer()) { - selectedEntity.bufferRenderer.commitBuffer(); - } else { - selectedEntity.bufferRenderer.clearBuffer(); - } - }; - - onStagePointerMove = async (_e: KonvaEventObject) => { - const cursorPos = this.parent.$cursorPos.get(); - - if (!cursorPos) { - return; - } - - if (!this.parent.$isPrimaryPointerDown.get()) { - return; - } - - const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - - if (!selectedEntity) { - return; - } - - const bufferState = selectedEntity.bufferRenderer.state; - - if (!bufferState) { - return; - } - - if (bufferState.type !== 'rect') { - return; - } - - const normalizedPoint = offsetCoord(cursorPos.relative, selectedEntity.state.position); - const alignedPoint = floorCoord(normalizedPoint); - bufferState.rect.width = Math.round(alignedPoint.x - bufferState.rect.x); - bufferState.rect.height = Math.round(alignedPoint.y - bufferState.rect.y); - await selectedEntity.bufferRenderer.setBuffer(bufferState); - }; -} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts new file mode 100644 index 00000000000..89c3f7691eb --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasShapeToolModule.ts @@ -0,0 +1,895 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule'; +import { shouldPreserveSuspendableShapesSession } from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; +import { + addCoords, + floorCoord, + getPrefixedId, + isDistanceMoreThanMin, + offsetCoord, +} from 'features/controlLayers/konva/util'; +import { selectShapeType } from 'features/controlLayers/store/canvasSettingsSlice'; +import type { + CanvasEntityIdentifier, + CanvasPolygonState, + CanvasRectState, + Coordinate, +} from 'features/controlLayers/store/types'; +import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify'; +import Konva from 'konva'; +import type { KonvaEventObject } from 'konva/lib/Node'; +import type { Logger } from 'roarr'; + +type CanvasShapeToolModuleConfig = { + START_POINT_RADIUS_PX: number; + START_POINT_STROKE_WIDTH_PX: number; + START_POINT_HOVER_RADIUS_DELTA_PX: number; + POLYGON_CLOSE_RADIUS_PX: number; + MIN_FREEHAND_POINT_DISTANCE_PX: number; + MAX_FREEHAND_SEGMENT_LENGTH_PX: number; + FREEHAND_SIMPLIFY_MIN_POINTS: number; + FREEHAND_SIMPLIFY_TOLERANCE: number; + PREVIEW_STROKE_COLOR: string; +}; + +const DEFAULT_CONFIG: CanvasShapeToolModuleConfig = { + START_POINT_RADIUS_PX: 4, + START_POINT_STROKE_WIDTH_PX: 2, + START_POINT_HOVER_RADIUS_DELTA_PX: 2, + POLYGON_CLOSE_RADIUS_PX: 10, + MIN_FREEHAND_POINT_DISTANCE_PX: 1, + MAX_FREEHAND_SEGMENT_LENGTH_PX: 2, + FREEHAND_SIMPLIFY_MIN_POINTS: 200, + FREEHAND_SIMPLIFY_TOLERANCE: 0.6, + PREVIEW_STROKE_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 1 }), +}; + +const SUBTRACT_CURSOR = `url("data:image/svg+xml,${encodeURIComponent( + ` + + + + + + + ` +)}") 12 12, crosshair`; + +const getAxisSign = (value: number, fallback: number): number => { + if (value === 0) { + return fallback === 0 ? 1 : Math.sign(fallback); + } + return Math.sign(value); +}; + +export class CanvasShapeToolModule extends CanvasModuleBase { + readonly type = 'shape_tool'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasToolModule; + readonly manager: CanvasManager; + readonly log: Logger; + + config: CanvasShapeToolModuleConfig = DEFAULT_CONFIG; + subscriptions: Set<() => void> = new Set(); + + private activeEntityIdentifier: CanvasEntityIdentifier | null = null; + private shapeId: string | null = null; + private dragStartPoint: Coordinate | null = null; + private dragCurrentPoint: Coordinate | null = null; + private translatePreviousPointerPoint: Coordinate | null = null; + private freehandPoints: Coordinate[] = []; + private isDrawingFreehand = false; + private polygonPoints: Coordinate[] = []; + private polygonPointer: Coordinate | null = null; + + konva: { + group: Konva.Group; + startPointIndicator: Konva.Circle; + }; + + constructor(parent: CanvasToolModule) { + super(); + this.id = getPrefixedId(this.type); + this.parent = parent; + this.manager = this.parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug('Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + startPointIndicator: new Konva.Circle({ + name: `${this.type}:start_point_indicator`, + listening: false, + fillEnabled: false, + stroke: this.config.PREVIEW_STROKE_COLOR, + visible: false, + perfectDrawEnabled: false, + }), + }; + this.konva.group.add(this.konva.startPointIndicator); + + this.subscriptions.add(this.manager.stateApi.$altKey.listen(this.onModifierChanged)); + this.subscriptions.add(this.manager.stateApi.$ctrlKey.listen(this.onModifierChanged)); + this.subscriptions.add(this.manager.stateApi.$metaKey.listen(this.onModifierChanged)); + this.subscriptions.add(this.manager.stateApi.$shiftKey.listen(this.onModifierChanged)); + this.subscriptions.add( + this.manager.stateApi.createStoreSubscription(selectShapeType, () => { + if (this.hasActiveSession()) { + this.cancel(); + } + this.render(); + }) + ); + } + + hasActiveSession = (): boolean => { + return Boolean( + this.dragStartPoint || this.isDrawingFreehand || this.freehandPoints.length || this.polygonPoints.length + ); + }; + + hasSuspendableSession = (): boolean => { + return Boolean(this.isDrawingFreehand || this.freehandPoints.length || this.polygonPoints.length); + }; + + hasActiveDragSession = (): boolean => { + return Boolean(this.dragStartPoint || this.isDrawingFreehand); + }; + + hasActiveRectOvalDragSession = (): boolean => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + return Boolean(this.dragStartPoint && this.dragCurrentPoint && (shapeType === 'rect' || shapeType === 'oval')); + }; + + hasActivePolygonSession = (): boolean => { + return this.polygonPoints.length > 0; + }; + + isTranslatingDragSession = (): boolean => { + return this.translatePreviousPointerPoint !== null; + }; + + freezePolygonPreview = async () => { + if (!this.hasActivePolygonSession()) { + return; + } + + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + if (!activeEntity || !cursorPos) { + return; + } + + const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + this.polygonPointer = point; + await this.updatePolygonBuffer(); + this.render(); + }; + + onToolChanged = () => { + const tool = this.parent.$tool.get(); + const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession( + tool, + this.parent.$toolBuffer.get(), + this.hasSuspendableSession() + ); + if (tool !== 'rect' && !isTemporaryToolSwitch) { + this.cancel(); + } + }; + + syncCursorStyle = () => { + this.manager.stage.setCursor(this.getCompositeOperation() === 'destination-out' ? SUBTRACT_CURSOR : 'crosshair'); + }; + + render = () => { + const tool = this.parent.$tool.get(); + const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession( + tool, + this.parent.$toolBuffer.get(), + this.hasSuspendableSession() + ); + if (tool !== 'rect' && !isTemporaryToolSwitch) { + this.konva.startPointIndicator.visible(false); + return; + } + + if (tool === 'rect') { + this.syncCursorStyle(); + } + + this.syncStartPointIndicator(); + }; + + cancel = () => { + this.clearActiveBuffer(); + this.resetState(); + this.render(); + }; + + startDragTranslation = () => { + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + if (!activeEntity || !cursorPos || !this.hasActiveRectOvalDragSession()) { + return; + } + + this.translatePreviousPointerPoint = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + }; + + stopDragTranslation = () => { + this.translatePreviousPointerPoint = null; + }; + + onStagePointerDown = async (e: KonvaEventObject) => { + const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + if (!selectedEntity || !cursorPos) { + return; + } + + if (e.evt.button !== 0) { + return; + } + + const shapeType = this.manager.stateApi.getSettings().shapeType; + const point = this.getEntityRelativePoint(cursorPos.relative, selectedEntity.state.position); + + if (shapeType === 'polygon') { + await this.onPolygonPointerDown(point, selectedEntity.entityIdentifier, e.evt.shiftKey); + return; + } + + if (shapeType === 'freehand') { + if (!this.parent.$isPrimaryPointerDown.get()) { + return; + } + + await this.startFreehandSession(point, selectedEntity.entityIdentifier); + return; + } + + if (!this.parent.$isPrimaryPointerDown.get()) { + return; + } + + this.clearActiveBuffer(); + this.resetState(); + this.activeEntityIdentifier = selectedEntity.entityIdentifier; + this.shapeId = getPrefixedId(shapeType); + this.dragStartPoint = point; + this.dragCurrentPoint = point; + await this.updateDragBuffer(); + }; + + onStagePointerMove = async (e: KonvaEventObject) => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + + if (!activeEntity || !cursorPos) { + return; + } + + const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + + if (shapeType === 'polygon') { + if (!this.hasActivePolygonSession()) { + return; + } + this.polygonPointer = this.getPolygonPoint(point, e.evt.shiftKey); + await this.updatePolygonBuffer(); + this.render(); + return; + } + + if (shapeType === 'freehand') { + await this.handleFreehandPointerMove(point); + return; + } + + if (!this.parent.$isPrimaryPointerDown.get() || !this.dragStartPoint) { + return; + } + + if (this.isTranslatingDragSession()) { + await this.translateDragShape(point); + return; + } + + this.dragCurrentPoint = point; + await this.updateDragBuffer(); + }; + + onWindowPointerMove = async () => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + const activeEntity = this.getActiveEntityAdapter(); + const cursorPos = this.parent.$cursorPos.get(); + + if (!activeEntity || !cursorPos || !this.parent.$isPrimaryPointerDown.get()) { + return; + } + + const point = this.getEntityRelativePoint(cursorPos.relative, activeEntity.state.position); + + if (shapeType === 'freehand') { + await this.handleFreehandPointerMove(point); + return; + } + + if ((shapeType !== 'rect' && shapeType !== 'oval') || !this.dragStartPoint) { + return; + } + + if (this.isTranslatingDragSession()) { + await this.translateDragShape(point); + return; + } + + this.dragCurrentPoint = point; + await this.updateDragBuffer(); + }; + + onStagePointerUp = async (_e: KonvaEventObject) => { + const shapeType = this.manager.stateApi.getSettings().shapeType; + + if (shapeType === 'polygon') { + this.render(); + return; + } + + if (shapeType === 'freehand') { + await this.commitFreehand(); + return; + } + + this.finishDragShapeSession(); + }; + + onWindowPointerUp = async () => { + if (this.isDrawingFreehand) { + await this.commitFreehand(); + return; + } + + if (!this.dragStartPoint) { + return; + } + + this.finishDragShapeSession(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + activeEntityIdentifier: this.activeEntityIdentifier, + shapeId: this.shapeId, + dragStartPoint: this.dragStartPoint, + dragCurrentPoint: this.dragCurrentPoint, + translatePreviousPointerPoint: this.translatePreviousPointerPoint, + freehandPoints: this.freehandPoints, + isDrawingFreehand: this.isDrawingFreehand, + polygonPoints: this.polygonPoints, + polygonPointer: this.polygonPointer, + }; + }; + + destroy = () => { + this.log.debug('Destroying module'); + this.subscriptions.forEach((unsubscribe) => unsubscribe()); + this.subscriptions.clear(); + this.konva.group.destroy(); + }; + + private onModifierChanged = () => { + const tool = this.parent.$tool.get(); + const isTemporaryToolSwitch = shouldPreserveSuspendableShapesSession( + tool, + this.parent.$toolBuffer.get(), + this.hasSuspendableSession() + ); + if (tool !== 'rect' && !isTemporaryToolSwitch) { + return; + } + + if (tool === 'rect') { + this.syncCursorStyle(); + } + void this.updateActivePreview(); + this.render(); + }; + + private updateActivePreview = async () => { + if (this.dragStartPoint) { + await this.updateDragBuffer(); + return; + } + + if (this.isDrawingFreehand || this.freehandPoints.length > 0) { + await this.updateFreehandBuffer(); + return; + } + + if (this.hasActivePolygonSession()) { + await this.updatePolygonBuffer(); + } + }; + + private startFreehandSession = async (point: Coordinate, entityIdentifier: CanvasEntityIdentifier) => { + this.clearActiveBuffer(); + this.resetState(); + this.activeEntityIdentifier = entityIdentifier; + this.shapeId = getPrefixedId('polygon'); + this.freehandPoints = [point]; + this.isDrawingFreehand = true; + await this.updateFreehandBuffer(); + }; + + private handleFreehandPointerMove = async (point: Coordinate) => { + if (!this.isDrawingFreehand || !this.parent.$isPrimaryPointerDown.get()) { + return; + } + + const minDistance = this.manager.stage.unscale(this.config.MIN_FREEHAND_POINT_DISTANCE_PX); + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!isDistanceMoreThanMin(point, lastPoint, minDistance)) { + return; + } + + this.appendFreehandPoint(point); + await this.updateFreehandBuffer(); + }; + + private translateDragShape = async (point: Coordinate) => { + if (!this.translatePreviousPointerPoint || !this.dragStartPoint || !this.dragCurrentPoint) { + return; + } + + const dx = point.x - this.translatePreviousPointerPoint.x; + const dy = point.y - this.translatePreviousPointerPoint.y; + + if (dx === 0 && dy === 0) { + return; + } + + this.dragStartPoint = { + x: this.dragStartPoint.x + dx, + y: this.dragStartPoint.y + dy, + }; + this.dragCurrentPoint = { + x: this.dragCurrentPoint.x + dx, + y: this.dragCurrentPoint.y + dy, + }; + this.translatePreviousPointerPoint = point; + + await this.updateDragBuffer(); + }; + + private onPolygonPointerDown = async ( + point: Coordinate, + entityIdentifier: CanvasEntityIdentifier, + shouldSnap: boolean + ) => { + if ( + this.activeEntityIdentifier && + (this.activeEntityIdentifier.id !== entityIdentifier.id || + this.activeEntityIdentifier.type !== entityIdentifier.type) + ) { + this.cancel(); + } + + this.activeEntityIdentifier = entityIdentifier; + this.dragStartPoint = null; + this.dragCurrentPoint = null; + + if (this.polygonPoints.length === 0) { + this.shapeId = getPrefixedId('polygon'); + this.polygonPoints = [point]; + this.polygonPointer = point; + await this.updatePolygonBuffer(); + this.render(); + return; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return; + } + + if (this.polygonPoints.length >= 3 && this.isPointNearStart(point)) { + await this.commitPolygon(); + return; + } + + const polygonPoint = this.getPolygonPoint(point, shouldSnap); + this.polygonPoints = [...this.polygonPoints, polygonPoint]; + this.polygonPointer = polygonPoint; + await this.updatePolygonBuffer(); + this.render(); + }; + + private commitPolygon = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId || this.polygonPoints.length < 3) { + this.cancel(); + return; + } + + const polygonState: CanvasPolygonState = { + id: this.shapeId, + type: 'polygon', + points: this.polygonPoints.flatMap((point) => [point.x, point.y]), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }; + + await activeEntity.bufferRenderer.setBuffer(polygonState); + activeEntity.bufferRenderer.commitBuffer(); + this.resetState(); + this.render(); + }; + + private commitFreehand = async () => { + if (!this.isDrawingFreehand) { + return; + } + + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId) { + this.cancel(); + return; + } + + const simplifiedPoints = this.simplifyFreehandContour(this.freehandPoints); + if (simplifiedPoints.length < 3) { + activeEntity.bufferRenderer.clearBuffer(); + this.resetState(); + this.render(); + return; + } + + const polygonState: CanvasPolygonState = { + id: this.shapeId, + type: 'polygon', + points: simplifiedPoints.flatMap((point) => [point.x, point.y]), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }; + + await activeEntity.bufferRenderer.setBuffer(polygonState); + activeEntity.bufferRenderer.commitBuffer(); + this.resetState(); + this.render(); + }; + + private updateDragBuffer = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.dragStartPoint || !this.dragCurrentPoint || !this.shapeId) { + return; + } + + const shapeType = this.manager.stateApi.getSettings().shapeType; + if (shapeType !== 'rect' && shapeType !== 'oval') { + return; + } + + const rect = this.getDragRect(this.dragStartPoint, this.dragCurrentPoint, { + fromCenter: this.manager.stateApi.$altKey.get(), + constrainSquare: this.manager.stateApi.$shiftKey.get(), + }); + + await activeEntity.bufferRenderer.setBuffer({ + id: this.shapeId, + type: shapeType, + rect, + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }); + }; + + private updatePolygonBuffer = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId || this.polygonPoints.length === 0) { + return; + } + + await activeEntity.bufferRenderer.setBuffer({ + id: this.shapeId, + type: 'polygon', + points: this.polygonPoints.flatMap((point) => [point.x, point.y]), + previewPoint: this.polygonPointer ?? this.polygonPoints.at(-1), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }); + }; + + private updateFreehandBuffer = async () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity || !this.shapeId || this.freehandPoints.length === 0) { + return; + } + + await activeEntity.bufferRenderer.setBuffer({ + id: this.shapeId, + type: 'polygon', + points: this.freehandPoints.flatMap((point) => [point.x, point.y]), + color: this.manager.stateApi.getCurrentColor(), + compositeOperation: this.getCompositeOperation(), + }); + }; + + private syncStartPointIndicator = () => { + const activeEntity = this.getActiveEntityAdapter(); + const startPoint = this.polygonPoints[0]; + if (!activeEntity || !startPoint || this.manager.stateApi.getSettings().shapeType !== 'polygon') { + this.konva.startPointIndicator.visible(false); + return; + } + + const isHoveringStartPoint = this.getIsHoveringStartPoint(startPoint, activeEntity.state.position); + const baseRadius = this.manager.stage.unscale(this.config.START_POINT_RADIUS_PX); + const stagePoint = addCoords(startPoint, activeEntity.state.position); + + this.konva.startPointIndicator.setAttrs({ + x: stagePoint.x, + y: stagePoint.y, + radius: + baseRadius + + (isHoveringStartPoint ? this.manager.stage.unscale(this.config.START_POINT_HOVER_RADIUS_DELTA_PX) : 0), + strokeWidth: this.manager.stage.unscale(this.config.START_POINT_STROKE_WIDTH_PX), + visible: true, + }); + }; + + private getEntityRelativePoint = (point: Coordinate, position: Coordinate): Coordinate => { + return floorCoord(offsetCoord(point, position)); + }; + + private getCompositeOperation = (): CanvasRectState['compositeOperation'] => { + return this.manager.stateApi.$ctrlKey.get() || this.manager.stateApi.$metaKey.get() + ? 'destination-out' + : 'source-over'; + }; + + private getPolygonPoint = (point: Coordinate, shouldSnap: boolean): Coordinate => { + if (!shouldSnap) { + return point; + } + + const lastPoint = this.polygonPoints.at(-1); + if (!lastPoint) { + return point; + } + + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + if (distance === 0) { + return point; + } + + const snapAngle = Math.PI / 4; + const angle = Math.atan2(dy, dx); + const snappedAngle = Math.round(angle / snapAngle) * snapAngle; + + const snappedPoint = { + x: lastPoint.x + Math.cos(snappedAngle) * distance, + y: lastPoint.y + Math.sin(snappedAngle) * distance, + }; + + return this.alignPointToStart(snappedPoint); + }; + + private isPointNearStart = (point: Coordinate): boolean => { + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return false; + } + return Math.hypot(point.x - startPoint.x, point.y - startPoint.y) <= this.getPolygonCloseRadius(); + }; + + private getPolygonCloseRadius = (): number => { + return this.manager.stage.unscale(this.config.POLYGON_CLOSE_RADIUS_PX); + }; + + private getIsHoveringStartPoint = (startPoint: Coordinate, entityPosition: Coordinate): boolean => { + if (this.polygonPoints.length < 3) { + return false; + } + + const pointerPoint = this.parent.$cursorPos.get()?.relative; + if (!pointerPoint) { + return false; + } + + const entityRelativePointerPoint = this.getEntityRelativePoint(pointerPoint, entityPosition); + return ( + Math.hypot(entityRelativePointerPoint.x - startPoint.x, entityRelativePointerPoint.y - startPoint.y) <= + this.getPolygonCloseRadius() + ); + }; + + private alignPointToStart = (point: Coordinate): Coordinate => { + if (this.polygonPoints.length < 2) { + return point; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return point; + } + + const alignThreshold = this.getPolygonCloseRadius(); + const deltaX = Math.abs(point.x - startPoint.x); + const deltaY = Math.abs(point.y - startPoint.y); + const canAlignX = deltaX <= alignThreshold; + const canAlignY = deltaY <= alignThreshold; + + if (!canAlignX && !canAlignY) { + return point; + } + + if (canAlignX && canAlignY) { + if (deltaX <= deltaY) { + return { x: startPoint.x, y: point.y }; + } + return { x: point.x, y: startPoint.y }; + } + + if (canAlignX) { + return { x: startPoint.x, y: point.y }; + } + + return { x: point.x, y: startPoint.y }; + }; + + private appendFreehandPoint = (point: Coordinate) => { + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!lastPoint) { + this.freehandPoints.push(point); + return; + } + + const maxSegmentLength = this.manager.stage.unscale(this.config.MAX_FREEHAND_SEGMENT_LENGTH_PX); + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + + if (distance <= maxSegmentLength) { + this.freehandPoints.push(point); + return; + } + + const steps = Math.ceil(distance / maxSegmentLength); + for (let i = 1; i <= steps; i++) { + const t = i / steps; + this.freehandPoints.push({ + x: lastPoint.x + dx * t, + y: lastPoint.y + dy * t, + }); + } + }; + + private simplifyFreehandContour = (points: Coordinate[]): Coordinate[] => { + if (points.length < this.config.FREEHAND_SIMPLIFY_MIN_POINTS) { + return points; + } + + const simplifiedFlatPoints = simplifyFlatNumbersArray( + points.flatMap((point) => [point.x, point.y]), + { + tolerance: this.config.FREEHAND_SIMPLIFY_TOLERANCE, + highestQuality: true, + } + ); + + if (simplifiedFlatPoints.length < 6) { + return points; + } + + const simplifiedPoints = this.flatNumbersToCoords(simplifiedFlatPoints); + if (simplifiedPoints.length < 3) { + return points; + } + + return simplifiedPoints; + }; + + private flatNumbersToCoords = (points: number[]): Coordinate[] => { + const coords: Coordinate[] = []; + for (let i = 0; i < points.length; i += 2) { + const x = points[i]; + const y = points[i + 1]; + if (x === undefined || y === undefined) { + continue; + } + coords.push({ x, y }); + } + return coords; + }; + + private getDragRect = ( + start: Coordinate, + end: Coordinate, + options: { fromCenter: boolean; constrainSquare: boolean } + ): CanvasRectState['rect'] => { + let dx = end.x - start.x; + let dy = end.y - start.y; + + if (options.constrainSquare) { + const size = Math.max(Math.abs(dx), Math.abs(dy)); + const dxSign = getAxisSign(dx, dy); + const dySign = getAxisSign(dy, dx); + dx = dxSign * size; + dy = dySign * size; + } + + const x1 = options.fromCenter ? start.x - dx : start.x; + const y1 = options.fromCenter ? start.y - dy : start.y; + const x2 = options.fromCenter ? start.x + dx : start.x + dx; + const y2 = options.fromCenter ? start.y + dy : start.y + dy; + + return { + x: Math.min(x1, x2), + y: Math.min(y1, y2), + width: Math.abs(x2 - x1), + height: Math.abs(y2 - y1), + }; + }; + + private getActiveEntityAdapter = () => { + if (!this.activeEntityIdentifier) { + return null; + } + return this.manager.getAdapter(this.activeEntityIdentifier); + }; + + private finishDragShapeSession = () => { + const activeEntity = this.getActiveEntityAdapter(); + if (!activeEntity) { + this.resetState(); + this.render(); + return; + } + + const bufferState = activeEntity.bufferRenderer.state; + if ( + bufferState && + (bufferState.type === 'rect' || bufferState.type === 'oval') && + activeEntity.bufferRenderer.hasBuffer() && + bufferState.rect.width > 0 && + bufferState.rect.height > 0 + ) { + activeEntity.bufferRenderer.commitBuffer(); + } else { + activeEntity.bufferRenderer.clearBuffer(); + } + + this.resetState(); + this.render(); + }; + + private clearActiveBuffer = () => { + this.getActiveEntityAdapter()?.bufferRenderer.clearBuffer(); + }; + + private resetState = () => { + this.activeEntityIdentifier = null; + this.shapeId = null; + this.dragStartPoint = null; + this.dragCurrentPoint = null; + this.translatePreviousPointerPoint = null; + this.freehandPoints = []; + this.isDrawingFreehand = false; + this.polygonPoints = []; + this.polygonPointer = null; + this.konva.startPointIndicator.visible(false); + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts index beca4d14a0a..98e1ab7d5d9 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts @@ -1,5 +1,6 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { CanvasBboxToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBboxToolModule'; import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasBrushToolModule'; import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule'; @@ -7,9 +8,15 @@ import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/ import { CanvasGradientToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasGradientToolModule'; import { CanvasLassoToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasLassoToolModule'; import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule'; -import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule'; +import { CanvasShapeToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasShapeToolModule'; import { CanvasTextToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasTextToolModule'; import { CanvasViewToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasViewToolModule'; +import { + getToolToCancelOnEscape, + shouldPreserveSuspendableShapesSession, + shouldQuickSwitchToColorPickerOnAlt, + shouldTranslateShapeDragOnSpace, +} from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; import { ZOOM_DRAG_CURSOR } from 'features/controlLayers/konva/cursors/zoomDragCursor'; import { calculateNewBrushSizeFromWheelDelta, @@ -64,7 +71,7 @@ export class CanvasToolModule extends CanvasModuleBase { tools: { brush: CanvasBrushToolModule; eraser: CanvasEraserToolModule; - rect: CanvasRectToolModule; + rect: CanvasShapeToolModule; lasso: CanvasLassoToolModule; gradient: CanvasGradientToolModule; colorPicker: CanvasColorPickerToolModule; @@ -124,7 +131,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools = { brush: new CanvasBrushToolModule(this), eraser: new CanvasEraserToolModule(this), - rect: new CanvasRectToolModule(this), + rect: new CanvasShapeToolModule(this), lasso: new CanvasLassoToolModule(this), gradient: new CanvasGradientToolModule(this), colorPicker: new CanvasColorPickerToolModule(this), @@ -141,6 +148,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.konva.group.add(this.tools.brush.konva.group); this.konva.group.add(this.tools.eraser.konva.group); + this.konva.group.add(this.tools.rect.konva.group); this.konva.group.add(this.tools.colorPicker.konva.group); this.konva.group.add(this.tools.text.konva.group); this.konva.group.add(this.tools.bbox.konva.group); @@ -152,17 +160,24 @@ export class CanvasToolModule extends CanvasModuleBase { this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSlice, this.render)); this.subscriptions.add( this.$tool.listen((tool, previousTool) => { - // Preserve pointer state during temporary view switching so lasso sessions can freeze/resume on space. - const shouldPreservePointerState = + // Preserve pointer state during temporary view switching so lasso and shapes sessions can freeze/resume on + // space. + const shouldPreserveLassoPointerState = this.$toolBuffer.get() === 'lasso' && this.tools.lasso.hasActiveSession() && ((previousTool === 'lasso' && tool === 'view') || (previousTool === 'view' && tool === 'lasso')); + const shouldPreserveShapesPointerState = + this.$toolBuffer.get() === 'rect' && + this.tools.rect.hasSuspendableSession() && + ((previousTool === 'rect' && tool === 'view') || (previousTool === 'view' && tool === 'rect')); + const shouldPreservePointerState = shouldPreserveLassoPointerState || shouldPreserveShapesPointerState; if (!shouldPreservePointerState) { // On tool switch, reset mouse state this.manager.tool.$isPrimaryPointerDown.set(false); } + this.tools.rect.onToolChanged(); this.tools.lasso.onToolChanged(); void this.tools.text.onToolChanged(); this.render(); @@ -239,6 +254,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.brush.render(); this.tools.eraser.render(); + this.tools.rect.render(); this.tools.colorPicker.render(); this.tools.text.render(); this.tools.bbox.render(); @@ -411,9 +427,8 @@ export class CanvasToolModule extends CanvasModuleBase { const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if ( - selectedEntity?.bufferRenderer.state?.type !== 'rect' && - selectedEntity?.bufferRenderer.state?.type !== 'gradient' && - selectedEntity?.bufferRenderer.hasBuffer() + selectedEntity?.bufferRenderer.hasBuffer() && + !this.shouldDeferEnterLeaveCommit(selectedEntity.bufferRenderer.state) ) { selectedEntity.bufferRenderer.commitBuffer(); return; @@ -467,7 +482,7 @@ export class CanvasToolModule extends CanvasModuleBase { } }; - onStagePointerUp = (e: KonvaEventObject) => { + onStagePointerUp = async (e: KonvaEventObject) => { if (e.target !== this.konva.stage) { return; } @@ -490,7 +505,7 @@ export class CanvasToolModule extends CanvasModuleBase { } else if (tool === 'eraser') { this.tools.eraser.onStagePointerUp(e); } else if (tool === 'rect') { - this.tools.rect.onStagePointerUp(e); + await this.tools.rect.onStagePointerUp(e); } else if (tool === 'lasso') { void this.tools.lasso.onStagePointerUp(e); } else if (tool === 'gradient') { @@ -534,6 +549,8 @@ export class CanvasToolModule extends CanvasModuleBase { await this.tools.gradient.onStagePointerMove(e); } else if (tool === 'text') { // Already handled above + } else if (this.isTemporaryShapesToolSwitch()) { + // Preserve in-progress polygon/freehand shapes while temporarily switching to view or color picker. } else { this.manager.stateApi.getSelectedEntityAdapter()?.bufferRenderer.clearBuffer(); } @@ -559,9 +576,8 @@ export class CanvasToolModule extends CanvasModuleBase { if ( selectedEntity && - selectedEntity.bufferRenderer.state?.type !== 'rect' && - selectedEntity.bufferRenderer.state?.type !== 'gradient' && - selectedEntity.bufferRenderer.hasBuffer() + selectedEntity.bufferRenderer.hasBuffer() && + !this.shouldDeferEnterLeaveCommit(selectedEntity.bufferRenderer.state) ) { selectedEntity.bufferRenderer.commitBuffer(); } @@ -604,20 +620,19 @@ export class CanvasToolModule extends CanvasModuleBase { this.render(); }; - /** - * Commit the buffer on window pointer up. - * - * The user may start drawing inside the stage and then release the mouse button outside of the stage. To prevent - * whatever the user was drawing from being lost, or ending up with stale state, we need to commit the buffer - * on window pointer up. - */ - onWindowPointerUp = (_: PointerEvent) => { + onWindowPointerUp = async (_: PointerEvent) => { try { this.$isPrimaryPointerDown.set(false); void this.tools.lasso.onWindowPointerUp(); + await this.tools.rect.onWindowPointerUp(); const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); - if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) { + if ( + selectedEntity && + selectedEntity.bufferRenderer.hasBuffer() && + !this.manager.$isBusy.get() && + !this.shouldSkipWindowPointerUpCommit(selectedEntity.bufferRenderer.state) + ) { selectedEntity.bufferRenderer.commitBuffer(); } } finally { @@ -625,36 +640,38 @@ export class CanvasToolModule extends CanvasModuleBase { } }; - onWindowPointerMove = (e: PointerEvent) => { + onWindowPointerMove = async (e: PointerEvent) => { const target = e.target; if (target instanceof Node && this.manager.stage.container.contains(target)) { return; } - if (this.$tool.get() !== 'lasso') { - return; - } - - if (!this.getCanDraw()) { - return; - } - - if (!this.$isPrimaryPointerDown.get()) { - return; - } - - if (!this.tools.lasso.hasActiveSession()) { - return; - } - try { this.$lastPointerType.set(e.pointerType); + if (!this.getCanDraw()) { + return; + } + + if (!this.$isPrimaryPointerDown.get()) { + return; + } + if (!this.syncCursorPositionsFromWindowEvent(e)) { return; } - this.tools.lasso.onWindowPointerMove(e); + if (this.$tool.get() === 'rect') { + if (!this.tools.rect.hasActiveDragSession()) { + return; + } + await this.tools.rect.onWindowPointerMove(); + } else if (this.$tool.get() === 'lasso') { + if (!this.tools.lasso.hasActiveSession()) { + return; + } + this.tools.lasso.onWindowPointerMove(e); + } } finally { this.render(); } @@ -665,6 +682,8 @@ export class CanvasToolModule extends CanvasModuleBase { * and the color picker tool is still active when you come back. */ onWindowBlur = () => { + this.manager.stateApi.$spaceKey.set(false); + this.tools.rect.stopDragTranslation(); this.revertToolBuffer(); }; @@ -691,9 +710,25 @@ export class CanvasToolModule extends CanvasModuleBase { if (e.key === KEY_ESCAPE) { // Cancel shape drawing on escape e.preventDefault(); - if (this.$tool.get() === 'lasso') { + const tool = this.$tool.get(); + const toolToCancel = getToolToCancelOnEscape( + tool, + this.$toolBuffer.get(), + this.tools.lasso.hasActiveSession(), + this.tools.rect.hasSuspendableSession() + ); + + this.manager.stateApi.$spaceKey.set(false); + this.tools.rect.stopDragTranslation(); + if (toolToCancel === 'rect') { + this.tools.rect.cancel(); + } + if (toolToCancel === 'lasso') { this.tools.lasso.reset(); } + if (toolToCancel && tool !== toolToCancel) { + this.revertToolBuffer(); + } const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if ( selectedEntity && @@ -707,16 +742,40 @@ export class CanvasToolModule extends CanvasModuleBase { } if (isSpaceKey) { - // Select the view tool on space key down e.preventDefault(); e.stopPropagation(); const currentTool = this.$tool.get(); - this.$toolBuffer.set(currentTool); + const shapeType = this.manager.stateApi.getSettings().shapeType; + const hasActiveShapeDragSession = this.tools.rect.hasActiveDragSession(); + const isPrimaryPointerDown = this.$isPrimaryPointerDown.get(); this.manager.stateApi.$spaceKey.set(true); + + if (shouldTranslateShapeDragOnSpace(currentTool, shapeType, hasActiveShapeDragSession, isPrimaryPointerDown)) { + this.tools.rect.startDragTranslation(); + return; + } + + if (currentTool === 'rect' && this.tools.rect.hasActivePolygonSession()) { + void this.tools.rect.freezePolygonPreview(); + } + + // Select the view tool on space key down + this.$toolBuffer.set(currentTool); this.$tool.set('view'); if (currentTool === 'lasso' && this.tools.lasso.hasActiveSession() && this.$isPrimaryPointerDown.get()) { // Start panning immediately if user is already drawing with freehand lasso. this.manager.stage.startDragging(); + } else if ( + currentTool === 'rect' && + this.tools.rect.hasSuspendableSession() && + this.$isPrimaryPointerDown.get() + ) { + // Match lasso: allow an in-progress freehand shapes session to freeze and pan immediately on space. + this.manager.stage.startDragging(); + } else if (currentTool === 'rect' && this.tools.rect.hasActivePolygonSession()) { + // Match polygon lasso: when a polygon session is active, Space should immediately enter panning without + // requiring an extra click on the canvas. + this.manager.stage.startDragging(); } else { this.$cursorPos.set(null); } @@ -724,10 +783,17 @@ export class CanvasToolModule extends CanvasModuleBase { } if (e.key === KEY_ALT) { + const tool = this.$tool.get(); + const shapeType = this.manager.stateApi.getSettings().shapeType; + const hasActiveShapeDragSession = this.tools.rect.hasActiveDragSession(); + if (!shouldQuickSwitchToColorPickerOnAlt(tool, shapeType, hasActiveShapeDragSession)) { + e.preventDefault(); + return; + } // Select the color picker on alt key down e.preventDefault(); e.stopPropagation(); - this.$toolBuffer.set(this.$tool.get()); + this.$toolBuffer.set(tool); this.$tool.set('colorPicker'); } }; @@ -747,11 +813,15 @@ export class CanvasToolModule extends CanvasModuleBase { } if (e.key === KEY_SPACE || e.code === CODE_SPACE) { - // Revert the tool to the previous tool on space key up e.preventDefault(); e.stopPropagation(); - this.revertToolBuffer(); this.manager.stateApi.$spaceKey.set(false); + if (this.tools.rect.isTranslatingDragSession()) { + this.tools.rect.stopDragTranslation(); + return; + } + // Revert the tool to the previous tool on space key up + this.revertToolBuffer(); return; } @@ -806,4 +876,28 @@ export class CanvasToolModule extends CanvasModuleBase { } this.konva.group.destroy(); }; + + private shouldDeferEnterLeaveCommit = (state: AnyObjectState | null) => { + if (!state) { + return false; + } + + if (state.type === 'rect' || state.type === 'oval' || state.type === 'gradient') { + return true; + } + + return state.type === 'polygon'; + }; + + private shouldSkipWindowPointerUpCommit = (state: AnyObjectState | null) => { + return Boolean(state?.type === 'polygon' && state.previewPoint); + }; + + private isTemporaryShapesToolSwitch = () => { + return shouldPreserveSuspendableShapesSession( + this.$tool.get(), + this.$toolBuffer.get(), + this.tools.rect.hasSuspendableSession() + ); + }; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts new file mode 100644 index 00000000000..e4fd800d2c2 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.test.ts @@ -0,0 +1,70 @@ +import { describe, expect, it } from 'vitest'; + +import { + getToolToCancelOnEscape, + shouldPreserveSuspendableShapesSession, + shouldQuickSwitchToColorPickerOnAlt, + shouldTranslateShapeDragOnSpace, +} from './toolHotkeys'; + +describe('tool hotkeys', () => { + it('keeps the color-picker quick-switch available before starting rect and oval drags', () => { + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'rect', false)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'oval', false)).toBe(true); + }); + + it('blocks the color-picker quick-switch while a rect, oval, or freehand drag is active', () => { + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'rect', true)).toBe(false); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'oval', true)).toBe(false); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'freehand', true)).toBe(false); + }); + + it('keeps the color-picker quick-switch for polygon mode and non-shape tools', () => { + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'polygon', false)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('rect', 'polygon', true)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('brush', 'rect', true)).toBe(true); + expect(shouldQuickSwitchToColorPickerOnAlt('lasso', 'polygon', false)).toBe(true); + }); + + it('uses Space to translate active rect and oval drags instead of switching to view', () => { + expect(shouldTranslateShapeDragOnSpace('rect', 'rect', true, true)).toBe(true); + expect(shouldTranslateShapeDragOnSpace('rect', 'oval', true, true)).toBe(true); + }); + + it('does not use Space translation outside active rect and oval drags', () => { + expect(shouldTranslateShapeDragOnSpace('rect', 'rect', false, true)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('rect', 'rect', true, false)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('rect', 'polygon', true, true)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('rect', 'freehand', true, true)).toBe(false); + expect(shouldTranslateShapeDragOnSpace('brush', 'rect', true, true)).toBe(false); + }); + + it('preserves suspendable shapes sessions across temporary view and color-picker switches', () => { + expect(shouldPreserveSuspendableShapesSession('view', 'rect', true)).toBe(true); + expect(shouldPreserveSuspendableShapesSession('colorPicker', 'rect', true)).toBe(true); + expect(shouldPreserveSuspendableShapesSession('rect', 'rect', true)).toBe(true); + }); + + it('does not preserve suspendable shapes sessions for unrelated tool switches', () => { + expect(shouldPreserveSuspendableShapesSession('brush', 'rect', true)).toBe(false); + expect(shouldPreserveSuspendableShapesSession('view', null, true)).toBe(false); + expect(shouldPreserveSuspendableShapesSession('colorPicker', 'rect', false)).toBe(false); + }); + + it('cancels the active drawing tool directly on escape', () => { + expect(getToolToCancelOnEscape('rect', null, false, false)).toBe('rect'); + expect(getToolToCancelOnEscape('lasso', null, false, false)).toBe('lasso'); + }); + + it('cancels preserved drawing sessions while temporarily switched away', () => { + expect(getToolToCancelOnEscape('view', 'lasso', true, false)).toBe('lasso'); + expect(getToolToCancelOnEscape('view', 'rect', false, true)).toBe('rect'); + expect(getToolToCancelOnEscape('colorPicker', 'rect', false, true)).toBe('rect'); + }); + + it('does not cancel unrelated buffered tools on escape', () => { + expect(getToolToCancelOnEscape('view', 'lasso', false, false)).toBeNull(); + expect(getToolToCancelOnEscape('colorPicker', 'lasso', true, false)).toBeNull(); + expect(getToolToCancelOnEscape('view', 'brush', false, true)).toBeNull(); + }); +}); diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts new file mode 100644 index 00000000000..84cbdc93dce --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/toolHotkeys.ts @@ -0,0 +1,65 @@ +import type { Tool } from 'features/controlLayers/store/types'; + +type ShapeType = 'rect' | 'oval' | 'polygon' | 'freehand'; + +export const shouldPreserveSuspendableShapesSession = ( + tool: Tool, + toolBuffer: Tool | null, + hasSuspendableShapeSession: boolean +): boolean => { + if (!hasSuspendableShapeSession || toolBuffer !== 'rect') { + return false; + } + + return tool === 'view' || tool === 'colorPicker' || tool === 'rect'; +}; + +export const shouldQuickSwitchToColorPickerOnAlt = ( + tool: Tool, + shapeType: ShapeType, + hasActiveShapeDragSession: boolean +): boolean => { + if (tool !== 'rect') { + return true; + } + + if (shapeType === 'polygon') { + return true; + } + + return !hasActiveShapeDragSession; +}; + +export const shouldTranslateShapeDragOnSpace = ( + tool: Tool, + shapeType: ShapeType, + hasActiveShapeDragSession: boolean, + isPrimaryPointerDown: boolean +): boolean => { + if (tool !== 'rect' || !hasActiveShapeDragSession || !isPrimaryPointerDown) { + return false; + } + + return shapeType === 'rect' || shapeType === 'oval'; +}; + +export const getToolToCancelOnEscape = ( + tool: Tool, + toolBuffer: Tool | null, + hasActiveLassoSession: boolean, + hasSuspendableShapeSession: boolean +): Tool | null => { + if (tool === 'rect' || tool === 'lasso') { + return tool; + } + + if (tool === 'view' && toolBuffer === 'lasso' && hasActiveLassoSession) { + return 'lasso'; + } + + if ((tool === 'view' || tool === 'colorPicker') && toolBuffer === 'rect' && hasSuspendableShapeSession) { + return 'rect'; + } + + return null; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts index 202b70e142d..509aefdaaf2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts @@ -14,6 +14,7 @@ export type TransformSmoothingMode = z.infer; const zGradientType = z.enum(['linear', 'radial']); const zLassoMode = z.enum(['freehand', 'polygon']); +const zShapeType = z.enum(['rect', 'oval', 'polygon', 'freehand']); const zCanvasSettingsState = z.object({ /** @@ -115,6 +116,10 @@ const zCanvasSettingsState = z.object({ * The gradient tool type. */ gradientType: zGradientType.default('linear'), + /** + * The shape tool type. + */ + shapeType: zShapeType.default('rect'), /** * Whether the gradient tool clips to the drag gesture. */ @@ -152,6 +157,7 @@ const getInitialState = (): CanvasSettingsState => ({ transformSmoothingEnabled: false, transformSmoothingMode: 'bicubic', gradientType: 'linear', + shapeType: 'rect', gradientClipEnabled: true, lassoMode: 'freehand', }); @@ -248,6 +254,9 @@ const slice = createSlice({ settingsGradientTypeChanged: (state, action: PayloadAction) => { state.gradientType = action.payload; }, + settingsShapeTypeChanged: (state, action: PayloadAction) => { + state.shapeType = action.payload; + }, settingsGradientClipToggled: (state) => { state.gradientClipEnabled = !state.gradientClipEnabled; }, @@ -284,6 +293,7 @@ export const { settingsStagingAreaAutoSwitchChanged, settingsFillColorPickerPinnedSet, settingsGradientTypeChanged, + settingsShapeTypeChanged, settingsGradientClipToggled, settingsLassoModeChanged, } = slice.actions; @@ -326,5 +336,6 @@ export const selectTransformSmoothingEnabled = createCanvasSettingsSelector( ); export const selectTransformSmoothingMode = createCanvasSettingsSelector((settings) => settings.transformSmoothingMode); export const selectGradientType = createCanvasSettingsSelector((settings) => settings.gradientType); +export const selectShapeType = createCanvasSettingsSelector((settings) => settings.shapeType); export const selectGradientClipEnabled = createCanvasSettingsSelector((settings) => settings.gradientClipEnabled); export const selectLassoMode = createCanvasSettingsSelector((settings) => settings.lassoMode); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 9e639c8e7af..a18a6ed308f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -70,7 +70,7 @@ import type { EntityLassoAddedPayload, EntityMovedToPayload, EntityRasterizedPayload, - EntityRectAddedPayload, + EntityShapeAddedPayload, IPMethodV2, T2IAdapterConfig, ZImageControlConfig, @@ -1568,8 +1568,8 @@ const slice = createSlice({ points: eraserLine.type === 'eraser_line' ? simplifyFlatNumbersArray(eraserLine.points) : eraserLine.points, }); }, - entityRectAdded: (state, action: PayloadAction) => { - const { entityIdentifier, rect } = action.payload; + entityShapeAdded: (state, action: PayloadAction) => { + const { entityIdentifier, shape } = action.payload; const entity = selectEntity(state, entityIdentifier); if (!entity) { return; @@ -1577,7 +1577,7 @@ const slice = createSlice({ // TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not // re-render it (reference equality check). I don't like this behaviour. - entity.objects.push({ ...rect }); + entity.objects.push({ ...shape }); }, entityLassoAdded: (state, action: PayloadAction) => { const { entityIdentifier, lasso } = action.payload; @@ -1910,7 +1910,7 @@ export const { entityRasterized, entityBrushLineAdded, entityEraserLineAdded, - entityRectAdded, + entityShapeAdded, entityLassoAdded, entityGradientAdded, // Raster layer adjustments @@ -2046,7 +2046,7 @@ export const canvasSliceConfig: SliceConfig = { const doNotGroupMatcher = isAnyOf( entityBrushLineAdded, entityEraserLineAdded, - entityRectAdded, + entityShapeAdded, entityLassoAdded, entityGradientAdded ); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 7a7ebeade71..cbeccdfa930 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -260,6 +260,7 @@ const zCanvasRectState = z.object({ type: z.literal('rect'), rect: zRect, color: zRgbaColor, + compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'), }); export type CanvasRectState = z.infer; @@ -277,6 +278,28 @@ const zCanvasLassoState = z.object({ }); export type CanvasLassoState = z.infer; +const zCanvasOvalState = z.object({ + id: zId, + type: z.literal('oval'), + rect: zRect, + color: zRgbaColor, + compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'), +}); +export type CanvasOvalState = z.infer; + +const zCanvasPolygonState = z.object({ + id: zId, + type: z.literal('polygon'), + points: zPoints, + color: zRgbaColor, + compositeOperation: z.enum(['source-over', 'destination-out']).default('source-over'), + previewPoint: zCoordinate.optional(), +}); +export type CanvasPolygonState = z.infer; + +const zCanvasShapeState = z.union([zCanvasRectState, zCanvasOvalState, zCanvasPolygonState]); +type CanvasShapeState = z.infer; + // Gradient state includes clip metadata so the tool can optionally clip to drag gesture. const zCanvasLinearGradientState = z.object({ id: zId, @@ -325,7 +348,7 @@ const zCanvasObjectState = z.union([ zCanvasImageState, zCanvasBrushLineState, zCanvasEraserLineState, - zCanvasRectState, + zCanvasShapeState, zCanvasLassoState, zCanvasBrushLineWithPressureState, zCanvasEraserLineWithPressureState, @@ -1020,8 +1043,8 @@ export type EntityBrushLineAddedPayload = EntityIdentifierPayload<{ export type EntityEraserLineAddedPayload = EntityIdentifierPayload<{ eraserLine: CanvasEraserLineState | CanvasEraserLineWithPressureState; }>; -export type EntityRectAddedPayload = EntityIdentifierPayload<{ rect: CanvasRectState }>; export type EntityLassoAddedPayload = EntityIdentifierPayload<{ lasso: CanvasLassoState }>; +export type EntityShapeAddedPayload = EntityIdentifierPayload<{ shape: CanvasShapeState }>; export type EntityGradientAddedPayload = EntityIdentifierPayload<{ gradient: CanvasGradientState }>; export type EntityRasterizedPayload = EntityIdentifierPayload<{ imageObject: CanvasImageState; diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx index 024e92328f9..d486bc1137f 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewCanvasHeaderActions.tsx @@ -3,7 +3,7 @@ import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; import type { IDockviewHeaderActionsProps } from 'dockview'; import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; -import { selectLassoMode } from 'features/controlLayers/store/canvasSettingsSlice'; +import { selectLassoMode, selectShapeType } from 'features/controlLayers/store/canvasSettingsSlice'; import { selectBbox } from 'features/controlLayers/store/selectors'; import type { Tool } from 'features/controlLayers/store/types'; import { IS_MAC_OS } from 'features/system/components/HotkeysModal/useHotkeyData'; @@ -16,6 +16,7 @@ import { WORKSPACE_PANEL_ID } from './shared'; const $fallbackTool = atom('move'); const $fallbackToolBuffer = atom(null); +const $fallbackPrimaryPointerDown = atom(false); const $fallbackTextSession = atom(null); type CanvasToolModifierHintKey = ReturnType[number]['keys'][number]; @@ -45,10 +46,12 @@ export const DockviewCanvasHeaderActions = memo((props: IDockviewHeaderActionsPr const { t } = useTranslation(); const canvasManager = useCanvasManagerSafe(); const lassoMode = useAppSelector(selectLassoMode); + const shapeType = useAppSelector(selectShapeType); const bboxAspectRatioLocked = useAppSelector((state) => selectBbox(state).aspectRatio.isLocked); const tool = useStore(canvasManager?.tool.$tool ?? $fallbackTool); const toolBuffer = useStore(canvasManager?.tool.$toolBuffer ?? $fallbackToolBuffer); + const isPrimaryPointerDown = useStore(canvasManager?.tool.$isPrimaryPointerDown ?? $fallbackPrimaryPointerDown); const textSession = useStore(canvasManager?.tool.tools.text.$session ?? $fallbackTextSession); const effectiveTool = useMemo(() => { @@ -66,10 +69,21 @@ export const DockviewCanvasHeaderActions = memo((props: IDockviewHeaderActionsPr return getCanvasToolModifierHints({ tool: effectiveTool, lassoMode, + shapeType, bboxAspectRatioLocked, hasActiveTextSession: Boolean(textSession), + isPrimaryPointerDown, }); - }, [bboxAspectRatioLocked, canvasManager, effectiveTool, lassoMode, props.activePanel?.id, textSession]); + }, [ + bboxAspectRatioLocked, + canvasManager, + effectiveTool, + isPrimaryPointerDown, + lassoMode, + props.activePanel?.id, + shapeType, + textSession, + ]); if (hints.length === 0) { return null; diff --git a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts index 7c599efad23..244634fd381 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts +++ b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.test.ts @@ -2,88 +2,116 @@ import { describe, expect, it } from 'vitest'; import { getCanvasToolModifierHintIds } from './canvasToolModifierHints'; +const buildArgs = (overrides: Partial[0]> = {}) => ({ + tool: 'brush' as const, + lassoMode: 'freehand' as const, + shapeType: 'rect' as const, + bboxAspectRatioLocked: false, + hasActiveTextSession: false, + isPrimaryPointerDown: false, + ...overrides, +}); + describe('getCanvasToolModifierHintIds', () => { it('returns brush hints in priority order', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'brush', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['shiftStraightLine', 'modWheelResizeBrush', 'spacePan', 'altPickColor']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'brush' }))).toEqual([ + 'shiftStraightLine', + 'modWheelResizeBrush', + 'spacePan', + 'altPickColor', + ]); }); it('omits alt color-picker hint for eraser', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'eraser', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['shiftStraightLine', 'modWheelResizeEraser', 'spacePan']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'eraser' }))).toEqual([ + 'shiftStraightLine', + 'modWheelResizeEraser', + 'spacePan', + ]); }); it('adds polygon snapping for polygon lasso', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'lasso', - lassoMode: 'polygon', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['modSubtractMask', 'shiftSnap45Degrees', 'spacePan']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'lasso', lassoMode: 'polygon' }))).toEqual([ + 'modErase', + 'shiftSnap45Degrees', + 'spacePan', + ]); }); it('omits polygon snapping for freehand lasso', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'lasso', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['modSubtractMask', 'spacePan']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'lasso' }))).toEqual(['modErase', 'spacePan']); }); it('switches the bbox aspect-ratio hint based on lock state', () => { - expect( - getCanvasToolModifierHintIds({ - tool: 'bbox', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['shiftLockAspectRatio', 'altScaleFromCenter', 'modFineGrid']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'bbox' }))).toEqual([ + 'shiftLockAspectRatio', + 'altScaleFromCenter', + 'modFineGrid', + ]); - expect( - getCanvasToolModifierHintIds({ - tool: 'bbox', - lassoMode: 'freehand', - bboxAspectRatioLocked: true, - hasActiveTextSession: false, - }) - ).toEqual(['shiftUnlockAspectRatio', 'altScaleFromCenter', 'modFineGrid']); + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'bbox', bboxAspectRatioLocked: true }))).toEqual([ + 'shiftUnlockAspectRatio', + 'altScaleFromCenter', + 'modFineGrid', + ]); }); it('only shows text-session hints when a text session is active', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'text', hasActiveTextSession: true }))).toEqual([ + 'enterCommitText', + 'shiftEnterNewLine', + 'escCancelText', + 'modDragText', + 'shiftSnapRotation', + ]); + + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'text' }))).toEqual(['spacePan', 'altPickColor']); + }); + + it('shows idle rect and oval shapes hints', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'rect' }))).toEqual([ + 'modErase', + 'shiftLockAspectRatio', + 'spacePan', + 'altPickColor', + ]); + + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'oval' }))).toEqual([ + 'modErase', + 'shiftLockAspectRatio', + 'spacePan', + 'altPickColor', + ]); + }); + + it('shows active rect and oval drag hints', () => { + expect( + getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'rect', isPrimaryPointerDown: true })) + ).toEqual(['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape']); + expect( - getCanvasToolModifierHintIds({ - tool: 'text', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: true, - }) - ).toEqual(['enterCommitText', 'shiftEnterNewLine', 'escCancelText', 'modDragText', 'shiftSnapRotation']); + getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'oval', isPrimaryPointerDown: true })) + ).toEqual(['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape']); + }); + + it('shows polygon shape hints', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'polygon' }))).toEqual([ + 'modErase', + 'shiftSnap45Degrees', + 'spacePan', + 'altPickColor', + ]); + }); + + it('omits alt color-picker hint during an active freehand stroke', () => { + expect(getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'freehand' }))).toEqual([ + 'modErase', + 'spacePan', + 'altPickColor', + ]); expect( - getCanvasToolModifierHintIds({ - tool: 'text', - lassoMode: 'freehand', - bboxAspectRatioLocked: false, - hasActiveTextSession: false, - }) - ).toEqual(['spacePan', 'altPickColor']); + getCanvasToolModifierHintIds(buildArgs({ tool: 'rect', shapeType: 'freehand', isPrimaryPointerDown: true })) + ).toEqual(['modErase', 'spacePan']); }); }); diff --git a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts index a23682e2072..3543ac4358d 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts +++ b/invokeai/frontend/web/src/features/ui/layouts/canvasToolModifierHints.ts @@ -1,14 +1,20 @@ +import { + shouldQuickSwitchToColorPickerOnAlt, + shouldTranslateShapeDragOnSpace, +} from 'features/controlLayers/konva/CanvasTool/toolHotkeys'; import type { Tool } from 'features/controlLayers/store/types'; +type ShapeType = 'rect' | 'oval' | 'polygon' | 'freehand'; type CanvasToolModifierHintKey = 'mod' | 'shift' | 'alt' | 'space' | 'wheel' | 'arrows' | 'enter' | 'esc'; type CanvasToolModifierHintId = | 'spacePan' + | 'spaceMoveShape' | 'altPickColor' | 'shiftStraightLine' | 'modWheelResizeBrush' | 'modWheelResizeEraser' - | 'modSubtractMask' + | 'modErase' | 'shiftSnap45Degrees' | 'shiftLockAspectRatio' | 'shiftUnlockAspectRatio' @@ -35,6 +41,11 @@ const HINTS: Record = { keys: ['space'], labelKey: 'controlLayers.modifierHints.labels.pan', }, + spaceMoveShape: { + id: 'spaceMoveShape', + keys: ['space'], + labelKey: 'controlLayers.modifierHints.labels.moveShape', + }, altPickColor: { id: 'altPickColor', keys: ['alt'], @@ -55,10 +66,10 @@ const HINTS: Record = { keys: ['mod', 'wheel'], labelKey: 'controlLayers.modifierHints.labels.resizeEraser', }, - modSubtractMask: { - id: 'modSubtractMask', + modErase: { + id: 'modErase', keys: ['mod'], - labelKey: 'controlLayers.modifierHints.labels.subtractMask', + labelKey: 'controlLayers.modifierHints.labels.erase', }, shiftSnap45Degrees: { id: 'shiftSnap45Degrees', @@ -120,8 +131,10 @@ const HINTS: Record = { type GetCanvasToolModifierHintsArg = { tool: Tool; lassoMode: 'freehand' | 'polygon'; + shapeType: ShapeType; bboxAspectRatioLocked: boolean; hasActiveTextSession: boolean; + isPrimaryPointerDown: boolean; }; const mapHintIdsToHints = (hintIds: readonly CanvasToolModifierHintId[]): CanvasToolModifierHint[] => @@ -130,8 +143,10 @@ const mapHintIdsToHints = (hintIds: readonly CanvasToolModifierHintId[]): Canvas export const getCanvasToolModifierHintIds = ({ tool, lassoMode, + shapeType, bboxAspectRatioLocked, hasActiveTextSession, + isPrimaryPointerDown, }: GetCanvasToolModifierHintsArg): CanvasToolModifierHintId[] => { // Resolver map: each tool returns the relevant hint ids based on the provided args. const TOOL_HINT_RESOLVERS: Record< @@ -141,7 +156,7 @@ export const getCanvasToolModifierHintIds = ({ brush: () => ['shiftStraightLine', 'modWheelResizeBrush', ...SHARED_HINT_IDS], eraser: () => ['shiftStraightLine', 'modWheelResizeEraser', 'spacePan'], lasso: ({ lassoMode: lm }) => - lm === 'polygon' ? ['modSubtractMask', 'shiftSnap45Degrees', 'spacePan'] : ['modSubtractMask', 'spacePan'], + lm === 'polygon' ? ['modErase', 'shiftSnap45Degrees', 'spacePan'] : ['modErase', 'spacePan'], bbox: ({ bboxAspectRatioLocked: locked }) => [ locked ? 'shiftUnlockAspectRatio' : 'shiftLockAspectRatio', 'altScaleFromCenter', @@ -155,7 +170,21 @@ export const getCanvasToolModifierHintIds = ({ view: () => ['altPickColor'], colorPicker: () => ['spacePan'], gradient: () => [...SHARED_HINT_IDS], - rect: () => [...SHARED_HINT_IDS], + rect: ({ shapeType: st, isPrimaryPointerDown: pointerDown }) => { + if (st === 'polygon') { + return ['modErase', 'shiftSnap45Degrees', 'spacePan', 'altPickColor']; + } + + if (st === 'freehand') { + return shouldQuickSwitchToColorPickerOnAlt('rect', st, pointerDown) + ? ['modErase', 'spacePan', 'altPickColor'] + : ['modErase', 'spacePan']; + } + + return shouldTranslateShapeDragOnSpace('rect', st, pointerDown, pointerDown) + ? ['modErase', 'shiftLockAspectRatio', 'altScaleFromCenter', 'spaceMoveShape'] + : ['modErase', 'shiftLockAspectRatio', 'spacePan', 'altPickColor']; + }, }; const resolver = TOOL_HINT_RESOLVERS[tool]; @@ -163,7 +192,9 @@ export const getCanvasToolModifierHintIds = ({ if (!resolver) { return []; } - return Array.from(resolver({ tool, lassoMode, bboxAspectRatioLocked, hasActiveTextSession })); + return Array.from( + resolver({ tool, lassoMode, shapeType, bboxAspectRatioLocked, hasActiveTextSession, isPrimaryPointerDown }) + ); }; export const getCanvasToolModifierHints = (args: GetCanvasToolModifierHintsArg): CanvasToolModifierHint[] => From af5992480cefdb12671e2c0f658560d651d0c12a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ufuk=20Sarp=20Sel=C3=A7ok?= Date: Thu, 14 May 2026 16:08:11 +0300 Subject: [PATCH 7/7] Remove optimized image-to-image toggle from UI for Z-Image (#9114) Co-authored-by: Alexander Eichhorn Co-authored-by: Lincoln Stein --- invokeai/frontend/web/src/features/modelManagerV2/models.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index 8f0e31ef5cd..63179db844a 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -276,7 +276,7 @@ export const MODEL_FORMAT_TO_LONG_NAME: Record = { unknown: 'Unknown', }; -export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3', 'z-image']; +export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3']; export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = ['sd-1', 'sdxl', 'flux', 'flux2', 'qwen-image'];