Skip to content

Commit c0fcc95

Browse files
Haoran Jiang authored and Huacai Chen committed
LoongArch: BPF: Fix the tailcall hierarchy
In specific use cases combining tailcalls and BPF-to-BPF calls, MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt back-propagation from callee to caller. This patch fixes this tailcall issue caused by abusing the tailcall in bpf2bpf feature on LoongArch like the way of "bpf, x64: Fix tailcall hierarchy". Push tail_call_cnt_ptr and tail_call_cnt into the stack, tail_call_cnt_ptr is passed between tailcall and bpf2bpf, uses tail_call_cnt_ptr to increment tail_call_cnt. Fixes: bb035ef ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls") Reviewed-by: Geliang Tang <geliang@kernel.org> Reviewed-by: Hengqi Chen <hengqi.chen@gmail.com> Signed-off-by: Haoran Jiang <jianghaoran@kylinos.cn> Signed-off-by: Huacai Chen <chenhuacai@loongson.cn>
1 parent cd39d9e commit c0fcc95

1 file changed

Lines changed: 107 additions & 48 deletions

File tree

arch/loongarch/net/bpf_jit.c

Lines changed: 107 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717
#define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4)
1818

1919
#define REG_TCC LOONGARCH_GPR_A6
20-
#define TCC_SAVED LOONGARCH_GPR_S5
21-
22-
#define SAVE_RA BIT(0)
23-
#define SAVE_TCC BIT(1)
20+
#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80)
2421

2522
static const int regmap[] = {
2623
/* return value from in-kernel function, and exit value for eBPF program */
@@ -42,32 +39,57 @@ static const int regmap[] = {
4239
[BPF_REG_AX] = LOONGARCH_GPR_T0,
4340
};
4441

45-
static void mark_call(struct jit_ctx *ctx)
42+
static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
4643
{
47-
ctx->flags |= SAVE_RA;
48-
}
44+
const struct bpf_prog *prog = ctx->prog;
45+
const bool is_main_prog = !bpf_is_subprog(prog);
4946

50-
static void mark_tail_call(struct jit_ctx *ctx)
51-
{
52-
ctx->flags |= SAVE_TCC;
53-
}
47+
if (is_main_prog) {
48+
/*
49+
* LOONGARCH_GPR_T3 = MAX_TAIL_CALL_CNT
50+
* if (REG_TCC > T3 )
51+
* std REG_TCC -> LOONGARCH_GPR_SP + store_offset
52+
* else
53+
* std REG_TCC -> LOONGARCH_GPR_SP + store_offset
54+
* REG_TCC = LOONGARCH_GPR_SP + store_offset
55+
*
56+
* std REG_TCC -> LOONGARCH_GPR_SP + store_offset
57+
*
58+
* The purpose of this code is to first push the TCC into stack,
59+
* and then push the address of TCC into stack.
60+
* In cases where bpf2bpf and tailcall are used in combination,
61+
* the value in REG_TCC may be a count or an address,
62+
* these two cases need to be judged and handled separately.
63+
*/
64+
emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
65+
*store_offset -= sizeof(long);
5466

55-
static bool seen_call(struct jit_ctx *ctx)
56-
{
57-
return (ctx->flags & SAVE_RA);
58-
}
67+
emit_cond_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
5968

60-
static bool seen_tail_call(struct jit_ctx *ctx)
61-
{
62-
return (ctx->flags & SAVE_TCC);
63-
}
69+
/*
70+
* If REG_TCC < MAX_TAIL_CALL_CNT, the value in REG_TCC is a count,
71+
* push tcc into stack
72+
*/
73+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
6474

65-
static u8 tail_call_reg(struct jit_ctx *ctx)
66-
{
67-
if (seen_call(ctx))
68-
return TCC_SAVED;
75+
/* Push the address of TCC into the REG_TCC */
76+
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
6977

70-
return REG_TCC;
78+
emit_uncond_jmp(ctx, 2);
79+
80+
/*
81+
* If REG_TCC > MAX_TAIL_CALL_CNT, the value in REG_TCC is an address,
82+
* push tcc_ptr into stack
83+
*/
84+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
85+
} else {
86+
*store_offset -= sizeof(long);
87+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
88+
}
89+
90+
/* Push tcc_ptr into stack */
91+
*store_offset -= sizeof(long);
92+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
7193
}
7294

7395
/*
@@ -90,6 +112,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
90112
* | $s4 |
91113
* +-------------------------+
92114
* | $s5 |
115+
* +-------------------------+
116+
* | tcc |
117+
* +-------------------------+
118+
* | tcc_ptr |
93119
* +-------------------------+ <--BPF_REG_FP
94120
* | prog->aux->stack_depth |
95121
* | (optional) |
@@ -99,12 +125,17 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
99125
static void build_prologue(struct jit_ctx *ctx)
100126
{
101127
int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
128+
const struct bpf_prog *prog = ctx->prog;
129+
const bool is_main_prog = !bpf_is_subprog(prog);
102130

103131
bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
104132

105-
/* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
133+
/* To store ra, fp, s0, s1, s2, s3, s4, s5 */
106134
stack_adjust += sizeof(long) * 8;
107135

136+
/* To store tcc and tcc_ptr */
137+
stack_adjust += sizeof(long) * 2;
138+
108139
stack_adjust = round_up(stack_adjust, 16);
109140
stack_adjust += bpf_stack_adjust;
110141

@@ -113,11 +144,12 @@ static void build_prologue(struct jit_ctx *ctx)
113144
emit_insn(ctx, nop);
114145

115146
/*
116-
* First instruction initializes the tail call count (TCC).
117-
* On tail call we skip this instruction, and the TCC is
118-
* passed in REG_TCC from the caller.
147+
* First instruction initializes the tail call count (TCC)
148+
* register to zero. On tail call we skip this instruction,
149+
* and the TCC is passed in REG_TCC from the caller.
119150
*/
120-
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
151+
if (is_main_prog)
152+
emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
121153

122154
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
123155

@@ -145,20 +177,13 @@ static void build_prologue(struct jit_ctx *ctx)
145177
store_offset -= sizeof(long);
146178
emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
147179

180+
prepare_bpf_tail_call_cnt(ctx, &store_offset);
181+
148182
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
149183

150184
if (bpf_stack_adjust)
151185
emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
152186

153-
/*
154-
* Program contains calls and tail calls, so REG_TCC need
155-
* to be saved across calls.
156-
*/
157-
if (seen_tail_call(ctx) && seen_call(ctx))
158-
move_reg(ctx, TCC_SAVED, REG_TCC);
159-
else
160-
emit_insn(ctx, nop);
161-
162187
ctx->stack_size = stack_adjust;
163188
}
164189

@@ -191,6 +216,16 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
191216
load_offset -= sizeof(long);
192217
emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
193218

219+
/*
220+
* When push into the stack, follow the order of tcc then tcc_ptr.
221+
* When pop from the stack, first pop tcc_ptr then followed by tcc.
222+
*/
223+
load_offset -= 2 * sizeof(long);
224+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
225+
226+
load_offset += sizeof(long);
227+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
228+
194229
emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
195230

196231
if (!is_tail_call) {
@@ -203,7 +238,7 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
203238
* Call the next bpf prog and skip the first instruction
204239
* of TCC initialization.
205240
*/
206-
emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1);
241+
emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 6);
207242
}
208243
}
209244

@@ -225,7 +260,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
225260
static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
226261
{
227262
int off, tc_ninsn = 0;
228-
u8 tcc = tail_call_reg(ctx);
263+
int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
229264
u8 a1 = LOONGARCH_GPR_A1;
230265
u8 a2 = LOONGARCH_GPR_A2;
231266
u8 t1 = LOONGARCH_GPR_T1;
@@ -252,11 +287,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
252287
goto toofar;
253288

254289
/*
255-
* if (--TCC < 0)
256-
* goto out;
290+
* if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
291+
* goto out;
257292
*/
258-
emit_insn(ctx, addid, REG_TCC, tcc, -1);
259-
if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
293+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
294+
emit_insn(ctx, ldd, t3, REG_TCC, 0);
295+
emit_insn(ctx, addid, t3, t3, 1);
296+
emit_insn(ctx, std, t3, REG_TCC, 0);
297+
emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
298+
if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
260299
goto toofar;
261300

262301
/*
@@ -467,7 +506,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
467506
u64 func_addr;
468507
bool func_addr_fixed, sign_extend;
469508
int i = insn - ctx->prog->insnsi;
470-
int ret, jmp_offset;
509+
int ret, jmp_offset, tcc_ptr_off;
471510
const u8 code = insn->code;
472511
const u8 cond = BPF_OP(code);
473512
const u8 t1 = LOONGARCH_GPR_T1;
@@ -903,12 +942,16 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
903942

904943
/* function call */
905944
case BPF_JMP | BPF_CALL:
906-
mark_call(ctx);
907945
ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
908946
&func_addr, &func_addr_fixed);
909947
if (ret < 0)
910948
return ret;
911949

950+
if (insn->src_reg == BPF_PSEUDO_CALL) {
951+
tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
952+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
953+
}
954+
912955
move_addr(ctx, t1, func_addr);
913956
emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
914957

@@ -919,7 +962,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
919962

920963
/* tail call */
921964
case BPF_JMP | BPF_TAIL_CALL:
922-
mark_tail_call(ctx);
923965
if (emit_bpf_tail_call(ctx, i) < 0)
924966
return -EINVAL;
925967
break;
@@ -1412,7 +1454,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
14121454
{
14131455
int i, ret, save_ret;
14141456
int stack_size = 0, nargs = 0;
1415-
int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off;
1457+
int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off, tcc_ptr_off;
14161458
bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
14171459
void *orig_call = func_addr;
14181460
struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
@@ -1447,6 +1489,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
14471489
*
14481490
* FP - sreg_off [ callee saved reg ]
14491491
*
1492+
* FP - tcc_ptr_off [ tail_call_cnt_ptr ]
14501493
*/
14511494

14521495
if (m->nr_args > LOONGARCH_MAX_REG_ARGS)
@@ -1489,6 +1532,12 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
14891532
stack_size += 8;
14901533
sreg_off = stack_size;
14911534

1535+
/* Room of trampoline frame to store tail_call_cnt_ptr */
1536+
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
1537+
stack_size += 8;
1538+
tcc_ptr_off = stack_size;
1539+
}
1540+
14921541
stack_size = round_up(stack_size, 16);
14931542

14941543
if (is_struct_ops) {
@@ -1519,6 +1568,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
15191568
emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size);
15201569
}
15211570

1571+
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
1572+
emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
1573+
15221574
/* callee saved register S1 to pass start time */
15231575
emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
15241576

@@ -1565,6 +1617,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
15651617

15661618
if (flags & BPF_TRAMP_F_CALL_ORIG) {
15671619
restore_args(ctx, m->nr_args, args_off);
1620+
1621+
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
1622+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
1623+
15681624
ret = emit_call(ctx, (const u64)orig_call);
15691625
if (ret)
15701626
goto out;
@@ -1605,6 +1661,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
16051661

16061662
emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
16071663

1664+
if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
1665+
emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
1666+
16081667
if (is_struct_ops) {
16091668
/* trampoline called directly */
16101669
emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, stack_size - 8);

0 commit comments

Comments (0)