task_work: only grab task signal lock when needed

If JOBCTL_TASK_WORK is already set on the targeted task, then we need
not go through {lock,unlock}_task_sighand() to set it again and queue
a signal wakeup. This is safe as we're checking it _after_ adding the
new task_work with cmpxchg().

The ordering is as follows:

task_work_add()				get_signal()
--------------------------------------------------------------
STORE(task->task_works, new_work);	STORE(task->jobctl);
mb();					mb();
LOAD(task->jobctl);			LOAD(task->task_works);

This speeds up TWA_SIGNAL handling quite a bit, which is important now
that io_uring is relying on it for all task_work deliveries.

Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Jann Horn <jannh@google.com>
Acked-by: Oleg Nesterov <oleg@redhat.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/kernel/signal.c b/kernel/signal.c
index 6f16f7c..42b67d2 100644
--- a/kernel/signal.c
+++ b/kernel/signal.c
@@ -2541,7 +2541,21 @@
 
 relock:
 	spin_lock_irq(&sighand->siglock);
-	current->jobctl &= ~JOBCTL_TASK_WORK;
+	/*
+	 * Make sure we can safely read ->jobctl() in task_work add. As Oleg
+	 * states:
+	 *
+	 * It pairs with mb (implied by cmpxchg) before READ_ONCE. So we
+	 * roughly have
+	 *
+	 *	task_work_add:				get_signal:
+	 *	STORE(task->task_works, new_work);	STORE(task->jobctl);
+	 *	mb();					mb();
+	 *	LOAD(task->jobctl);			LOAD(task->task_works);
+	 *
+	 * and we can rely on STORE-MB-LOAD [ in task_work_add].
+	 */
+	smp_store_mb(current->jobctl, current->jobctl & ~JOBCTL_TASK_WORK);
 	if (unlikely(current->task_works)) {
 		spin_unlock_irq(&sighand->siglock);
 		task_work_run();
diff --git a/kernel/task_work.c b/kernel/task_work.c
index 5c0848c..613b2d6 100644
--- a/kernel/task_work.c
+++ b/kernel/task_work.c
@@ -42,7 +42,13 @@
 		set_notify_resume(task);
 		break;
 	case TWA_SIGNAL:
-		if (lock_task_sighand(task, &flags)) {
+		/*
+		 * Only grab the sighand lock if we don't already have some
+		 * task_work pending. This pairs with the smp_store_mb()
+		 * in get_signal(), see comment there.
+		 */
+		if (!(READ_ONCE(task->jobctl) & JOBCTL_TASK_WORK) &&
+		    lock_task_sighand(task, &flags)) {
 			task->jobctl |= JOBCTL_TASK_WORK;
 			signal_wake_up(task, 0);
 			unlock_task_sighand(task, &flags);