mempool: Implement populate none policy
authorMathieu Desnoyers <mathieu.desnoyers@efficios.com>
Thu, 14 Mar 2024 01:50:56 +0000 (21:50 -0400)
committerMathieu Desnoyers <mathieu.desnoyers@efficios.com>
Sat, 16 Mar 2024 20:41:58 +0000 (16:41 -0400)
Implement lazy-populate policy (populate none) relying on kernel COW of
init values pages.

Signed-off-by: Mathieu Desnoyers <mathieu.desnoyers@efficios.com>
Change-Id: Ie4467f1806208271a5462605445abc2f0c1c1d38

include/rseq/mempool.h
src/rseq-mempool.c
tests/mempool_test.c

index 8dd5078038263b46a0aff73cbbbf73f1696799aa..d1564c52dee07cc15c1a32df8bc0446b1b34c20b 100644 (file)
@@ -555,6 +555,33 @@ int rseq_mempool_attr_set_max_nr_ranges(struct rseq_mempool_attr *attr,
 int rseq_mempool_attr_set_poison(struct rseq_mempool_attr *attr,
                uintptr_t poison);
 
+enum rseq_mempool_populate_policy {
+       /*
+        * RSEQ_MEMPOOL_POPULATE_NONE (default):
+        *   Do not populate pages for any of the CPUs when creating the
+        *   mempool. Rely on copy-on-write (COW) of per-cpu pages to
+        *   populate per-cpu pages from the initial values pages on
+        *   first write.
+        */
+       RSEQ_MEMPOOL_POPULATE_NONE = 0,
+       /*
+        * RSEQ_MEMPOOL_POPULATE_ALL:
+        *   Populate pages for all CPUs from 0 to (max_nr_cpus - 1)
+        *   when creating the mempool.
+        */
+       RSEQ_MEMPOOL_POPULATE_ALL = 1,
+};
+
+/*
+ * rseq_mempool_attr_set_populate_policy: Set pool page populate policy.
+ *
+ * Set page populate policy for the mempool.
+ *
+ * Returns 0 on success, -1 with errno=EINVAL if arguments are invalid.
+ */
+int rseq_mempool_attr_set_populate_policy(struct rseq_mempool_attr *attr,
+               enum rseq_mempool_populate_policy policy);
+
 /*
  * rseq_mempool_range_init_numa: NUMA initialization helper for memory range.
  *
index 8b25e0b886ff9c556920dd3686cb104ec974d84b..d2bc0c9de886ab0f3e462b582345a6f1577bfbba 100644 (file)
@@ -13,6 +13,7 @@
 #include <stdint.h>
 #include <stdbool.h>
 #include <stdio.h>
+#include <fcntl.h>
 
 #ifdef HAVE_LIBNUMA
 # include <numa.h>
@@ -89,6 +90,8 @@ struct rseq_mempool_attr {
 
        bool poison_set;
        uintptr_t poison;
+
+       enum rseq_mempool_populate_policy populate_policy;
 };
 
 struct rseq_mempool_range;
@@ -96,8 +99,23 @@ struct rseq_mempool_range;
 struct rseq_mempool_range {
        struct rseq_mempool_range *next;        /* Linked list of ranges. */
        struct rseq_mempool *pool;              /* Backward reference to container pool. */
+
+       /*
+        * Memory layout of a mempool range:
+        * - Header page (contains struct rseq_mempool_range at the very end),
+        * - Base of the per-cpu data, starting with CPU 0,
+        * - CPU 1,
+        * ...
+        * - CPU max_nr_cpus - 1
+        * - init values (unpopulated for RSEQ_MEMPOOL_POPULATE_ALL).
+        */
        void *header;
        void *base;
+       /*
+        * The init values contains malloc_init/zmalloc values.
+        * Pointer is NULL for RSEQ_MEMPOOL_POPULATE_ALL.
+        */
+       void *init;
        size_t next_unused;
 
        /* Pool range mmap/munmap */
@@ -143,6 +161,26 @@ struct rseq_mempool_set {
        struct rseq_mempool *entries[POOL_SET_NR_ENTRIES];
 };
 
+/*
+ * This memfd is used to implement the user COW behavior for the page
+ * protection scheme. memfd is a sparse virtual file. Its layout (in
+ * offset from beginning of file) matches the process address space
+ * (pointers directly converted to file offsets).
+ */
+struct rseq_memfd {
+       pthread_mutex_t lock;
+       size_t reserved_size;
+       unsigned int refcount;
+       int fd;
+};
+
+static struct rseq_memfd memfd = {
+       .lock = PTHREAD_MUTEX_INITIALIZER,
+       .reserved_size = 0,
+       .refcount = 0,
+       .fd = -1,
+};
+
 static
 const char *get_pool_name(const struct rseq_mempool *pool)
 {
@@ -156,15 +194,63 @@ void *__rseq_pool_range_percpu_ptr(const struct rseq_mempool_range *range, int c
        return range->base + (stride * cpu) + item_offset;
 }
 
+static
+void *__rseq_pool_range_init_ptr(const struct rseq_mempool_range *range,
+               uintptr_t item_offset)
+{
+       if (!range->init)
+               return NULL;
+       return range->init + item_offset;
+}
+
+static
+void __rseq_percpu *__rseq_free_list_to_percpu_ptr(const struct rseq_mempool *pool,
+               struct free_list_node *node)
+{
+       void __rseq_percpu *p = (void __rseq_percpu *) node;
+
+       if (pool->attr.populate_policy != RSEQ_MEMPOOL_POPULATE_ALL)
+               p -= pool->attr.max_nr_cpus * pool->attr.stride;
+       return p;
+}
+
+static
+struct free_list_node *__rseq_percpu_to_free_list_ptr(const struct rseq_mempool *pool,
+               void __rseq_percpu *p)
+{
+       if (pool->attr.populate_policy != RSEQ_MEMPOOL_POPULATE_ALL)
+               p += pool->attr.max_nr_cpus * pool->attr.stride;
+       return (struct free_list_node *) p;
+}
+
+static
+int memcmpbyte(const char *s, int c, size_t n)
+{
+       int res = 0;
+
+       while (n-- > 0)
+               if ((res = *(s++) - c) != 0)
+                       break;
+       return res;
+}
+
 static
 void rseq_percpu_zero_item(struct rseq_mempool *pool,
                struct rseq_mempool_range *range, uintptr_t item_offset)
 {
+       char *init_p = NULL;
        int i;
 
+       init_p = __rseq_pool_range_init_ptr(range, item_offset);
+       if (init_p)
+               memset(init_p, 0, pool->item_len);
        for (i = 0; i < pool->attr.max_nr_cpus; i++) {
                char *p = __rseq_pool_range_percpu_ptr(range, i,
                                item_offset, pool->attr.stride);
+
+               /* Update propagated */
+               if (init_p && !memcmpbyte(p, 0, pool->item_len))
+                       continue;
                memset(p, 0, pool->item_len);
        }
 }
@@ -174,29 +260,73 @@ void rseq_percpu_init_item(struct rseq_mempool *pool,
                struct rseq_mempool_range *range, uintptr_t item_offset,
                void *init_ptr, size_t init_len)
 {
+       char *init_p = NULL;
        int i;
 
+       init_p = __rseq_pool_range_init_ptr(range, item_offset);
+       if (init_p)
+               memcpy(init_p, init_ptr, init_len);
        for (i = 0; i < pool->attr.max_nr_cpus; i++) {
                char *p = __rseq_pool_range_percpu_ptr(range, i,
                                item_offset, pool->attr.stride);
+
+               /* Update propagated */
+               if (init_p && !memcmp(init_p, p, init_len))
+                       continue;
                memcpy(p, init_ptr, init_len);
        }
 }
 
+static
+void rseq_poison_item(void *p, size_t item_len, uintptr_t poison)
+{
+       size_t offset;
+
+       for (offset = 0; offset < item_len; offset += sizeof(uintptr_t))
+               *((uintptr_t *) (p + offset)) = poison;
+}
+
 static
 void rseq_percpu_poison_item(struct rseq_mempool *pool,
                struct rseq_mempool_range *range, uintptr_t item_offset)
 {
        uintptr_t poison = pool->attr.poison;
+       char *init_p = NULL;
        int i;
 
+       init_p = __rseq_pool_range_init_ptr(range, item_offset);
+       if (init_p)
+               rseq_poison_item(init_p, pool->item_len, poison);
        for (i = 0; i < pool->attr.max_nr_cpus; i++) {
                char *p = __rseq_pool_range_percpu_ptr(range, i,
                                item_offset, pool->attr.stride);
-               size_t offset;
 
-               for (offset = 0; offset < pool->item_len; offset += sizeof(uintptr_t))
-                       *((uintptr_t *) (p + offset)) = poison;
+               /* Update propagated */
+               if (init_p && !memcmp(init_p, p, pool->item_len))
+                       continue;
+               rseq_poison_item(p, pool->item_len, poison);
+       }
+}
+
+/* Always inline for __builtin_return_address(0). */
+static inline __attribute__((always_inline))
+void rseq_check_poison_item(const struct rseq_mempool *pool, uintptr_t item_offset,
+               void *p, size_t item_len, uintptr_t poison, bool skip_freelist_ptr)
+{
+       size_t offset;
+
+       for (offset = 0; offset < item_len; offset += sizeof(uintptr_t)) {
+               uintptr_t v;
+
+               /* Skip poison check for free-list pointer. */
+               if (skip_freelist_ptr && offset == 0)
+                       continue;
+               v = *((uintptr_t *) (p + offset));
+               if (v != poison) {
+                       fprintf(stderr, "%s: Poison corruption detected (0x%lx) for pool: \"%s\" (%p), item offset: %zu, caller: %p.\n",
+                               __func__, (unsigned long) v, get_pool_name(pool), pool, item_offset, (void *) __builtin_return_address(0));
+                       abort();
+               }
        }
 }
 
@@ -206,28 +336,29 @@ void rseq_percpu_check_poison_item(const struct rseq_mempool *pool,
                const struct rseq_mempool_range *range, uintptr_t item_offset)
 {
        uintptr_t poison = pool->attr.poison;
+       char *init_p;
        int i;
 
        if (!pool->attr.robust_set)
                return;
+       init_p = __rseq_pool_range_init_ptr(range, item_offset);
+       if (init_p)
+               rseq_check_poison_item(pool, item_offset, init_p, pool->item_len, poison, true);
        for (i = 0; i < pool->attr.max_nr_cpus; i++) {
                char *p = __rseq_pool_range_percpu_ptr(range, i,
                                item_offset, pool->attr.stride);
-               size_t offset;
-
-               for (offset = 0; offset < pool->item_len; offset += sizeof(uintptr_t)) {
-                       uintptr_t v;
-
-                       /* Skip poison check for free-list pointer. */
-                       if (i == 0 && offset == 0)
-                               continue;
-                       v = *((uintptr_t *) (p + offset));
-                       if (v != poison) {
-                               fprintf(stderr, "%s: Poison corruption detected (0x%lx) for pool: \"%s\" (%p), item offset: %zu, caller: %p.\n",
-                                       __func__, (unsigned long) v, get_pool_name(pool), pool, item_offset, (void *) __builtin_return_address(0));
-                               abort();
-                       }
-               }
+               /*
+                * When the free list is embedded in the init values
+                * memory (populate none), it is visible from the init
+                * values memory mapping as well as per-cpu private
+                * mappings before they COW.
+                *
+                * When the free list is embedded in CPU 0 mapping
+                * (populate all), only this CPU must skip the free list
+                * nodes when checking poison.
+                */
+               rseq_check_poison_item(pool, item_offset, p, pool->item_len, poison,
+                       init_p == NULL ? (i == 0) : true);
        }
 }
 
@@ -333,9 +464,10 @@ int create_alloc_bitmap(struct rseq_mempool *pool, struct rseq_mempool_range *ra
 }
 
 static
-bool addr_in_pool(const struct rseq_mempool *pool, void *addr)
+bool percpu_addr_in_pool(const struct rseq_mempool *pool, void __rseq_percpu *_addr)
 {
        struct rseq_mempool_range *range;
+       void *addr = (void *) _addr;
 
        for (range = pool->range_list; range; range = range->next) {
                if (addr >= range->base && addr < range->base + range->next_unused)
@@ -366,8 +498,6 @@ void check_free_list(const struct rseq_mempool *pool)
             prev = node,
             node = node->next) {
 
-               void *node_addr = node;
-
                if (traversal_iteration >= max_list_traversal) {
                        fprintf(stderr, "%s: Corrupted free-list; Possibly infinite loop in pool \"%s\" (%p), caller %p.\n",
                                __func__, get_pool_name(pool), pool, __builtin_return_address(0));
@@ -375,7 +505,7 @@ void check_free_list(const struct rseq_mempool *pool)
                }
 
                /* Node is out of range. */
-               if (!addr_in_pool(pool, node_addr)) {
+               if (!percpu_addr_in_pool(pool, __rseq_free_list_to_percpu_ptr(pool, node))) {
                        if (prev)
                                fprintf(stderr, "%s: Corrupted free-list node %p -> [out-of-range %p] in pool \"%s\" (%p), caller %p.\n",
                                        __func__, prev, node, get_pool_name(pool), pool, __builtin_return_address(0));
@@ -442,6 +572,7 @@ void destroy_alloc_bitmap(struct rseq_mempool *pool, struct rseq_mempool_range *
        }
 
        free(bitmap);
+       range->alloc_bitmap = NULL;
 }
 
 /* Always inline for __builtin_return_address(0). */
@@ -449,7 +580,21 @@ static inline __attribute__((always_inline))
 int rseq_mempool_range_destroy(struct rseq_mempool *pool,
                struct rseq_mempool_range *range)
 {
+       int ret = 0;
+
        destroy_alloc_bitmap(pool, range);
+
+       /*
+        * Punch a hole into memfd where the init values used to be.
+        */
+       if (range->init) {
+               ret = fallocate(memfd.fd, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE,
+                       (off_t) range->init, pool->attr.stride);
+               if (ret)
+                       return ret;
+               range->init = NULL;
+       }
+
        /* range is a header located one page before the aligned mapping. */
        return pool->attr.munmap_func(pool->attr.mmap_priv, range->mmap_addr, range->mmap_len);
 }
@@ -536,6 +681,26 @@ alloc_error:
        return ptr;
 }
 
+static
+int rseq_memfd_reserve_init(void *init, size_t init_len)
+{
+       int ret = 0;
+       size_t reserve_len;
+
+       pthread_mutex_lock(&memfd.lock);
+       reserve_len = (size_t) init + init_len;
+       if (reserve_len > memfd.reserved_size) {
+               if (ftruncate(memfd.fd, (off_t) reserve_len)) {
+                       ret = -1;
+                       goto unlock;
+               }
+               memfd.reserved_size = reserve_len;
+       }
+unlock:
+       pthread_mutex_unlock(&memfd.lock);
+       return ret;
+}
+
 static
 struct rseq_mempool_range *rseq_mempool_range_create(struct rseq_mempool *pool)
 {
@@ -543,6 +708,7 @@ struct rseq_mempool_range *rseq_mempool_range_create(struct rseq_mempool *pool)
        unsigned long page_size;
        void *header;
        void *base;
+       size_t range_len;       /* Range len excludes header. */
 
        if (pool->attr.max_nr_ranges &&
                        pool->nr_ranges >= pool->attr.max_nr_ranges) {
@@ -551,18 +717,51 @@ struct rseq_mempool_range *rseq_mempool_range_create(struct rseq_mempool *pool)
        }
        page_size = rseq_get_page_len();
 
+       range_len = pool->attr.stride * pool->attr.max_nr_cpus;
+       if (pool->attr.populate_policy != RSEQ_MEMPOOL_POPULATE_ALL)
+               range_len += pool->attr.stride; /* init values */
        base = aligned_mmap_anonymous(pool, page_size,
-                       pool->attr.stride * pool->attr.max_nr_cpus,
+                       range_len,
                        pool->attr.stride,
                        &header, page_size);
        if (!base)
                return NULL;
        range = (struct rseq_mempool_range *) (base - RANGE_HEADER_OFFSET);
        range->pool = pool;
-       range->base = base;
        range->header = header;
+       range->base = base;
        range->mmap_addr = header;
-       range->mmap_len = page_size + (pool->attr.stride * pool->attr.max_nr_cpus);
+       range->mmap_len = page_size + range_len;
+
+       if (pool->attr.populate_policy != RSEQ_MEMPOOL_POPULATE_ALL) {
+               range->init = base + (pool->attr.stride * pool->attr.max_nr_cpus);
+               /* Populate init values pages from memfd */
+               if (rseq_memfd_reserve_init(range->init, pool->attr.stride))
+                       goto error_alloc;
+               if (mmap(range->init, pool->attr.stride, PROT_READ | PROT_WRITE,
+                               MAP_SHARED | MAP_FIXED, memfd.fd,
+                               (off_t) range->init) != (void *) range->init) {
+                       goto error_alloc;
+               }
+               assert(pool->attr.type == MEMPOOL_TYPE_PERCPU);
+               /*
+                * Map per-cpu memory as private COW mappings of init values.
+                */
+               {
+                       int cpu;
+
+                       for (cpu = 0; cpu < pool->attr.max_nr_cpus; cpu++) {
+                               void *p = base + (pool->attr.stride * cpu);
+                               size_t len = pool->attr.stride;
+
+                               if (mmap(p, len, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_FIXED,
+                                               memfd.fd, (off_t) range->init) != (void *) p) {
+                                       goto error_alloc;
+                               }
+                       }
+               }
+       }
+
        if (pool->attr.robust_set) {
                if (create_alloc_bitmap(pool, range))
                        goto error_alloc;
@@ -599,6 +798,48 @@ error_alloc:
        return NULL;
 }
 
+static
+int rseq_mempool_memfd_ref(struct rseq_mempool *pool)
+{
+       int ret = 0;
+
+       if (pool->attr.populate_policy == RSEQ_MEMPOOL_POPULATE_ALL)
+               return 0;
+
+       pthread_mutex_lock(&memfd.lock);
+       if (memfd.refcount == 0) {
+               memfd.fd = memfd_create("mempool", MFD_CLOEXEC);
+               if (memfd.fd < 0) {
+                       perror("memfd_create");
+                       ret = -1;
+                       goto unlock;
+               }
+       }
+       memfd.refcount++;
+unlock:
+       pthread_mutex_unlock(&memfd.lock);
+       return ret;
+}
+
+static
+void rseq_mempool_memfd_unref(struct rseq_mempool *pool)
+{
+       if (pool->attr.populate_policy == RSEQ_MEMPOOL_POPULATE_ALL)
+               return;
+
+       pthread_mutex_lock(&memfd.lock);
+       if (memfd.refcount == 1) {
+               if (close(memfd.fd)) {
+                       perror("close");
+                       abort();
+               }
+               memfd.fd = -1;
+               memfd.reserved_size = 0;
+       }
+       memfd.refcount--;
+       pthread_mutex_unlock(&memfd.lock);
+}
+
 int rseq_mempool_destroy(struct rseq_mempool *pool)
 {
        struct rseq_mempool_range *range, *next_range;
@@ -615,6 +856,7 @@ int rseq_mempool_destroy(struct rseq_mempool *pool)
                /* Update list head to keep list coherent in case of partial failure. */
                pool->range_list = next_range;
        }
+       rseq_mempool_memfd_unref(pool);
        pthread_mutex_destroy(&pool->lock);
        free(pool->name);
        free(pool);
@@ -665,6 +907,8 @@ struct rseq_mempool *rseq_mempool_create(const char *pool_name,
                }
                break;
        case MEMPOOL_TYPE_GLOBAL:
+               /* Override populate policy for global type. */
+               attr.populate_policy = RSEQ_MEMPOOL_POPULATE_ALL;
                /* Use a 1-cpu pool for global mempool type. */
                attr.max_nr_cpus = 1;
                break;
@@ -690,6 +934,9 @@ struct rseq_mempool *rseq_mempool_create(const char *pool_name,
        pool->item_len = item_len;
        pool->item_order = order;
 
+       if (rseq_mempool_memfd_ref(pool))
+               goto error_alloc;
+
        pool->range_list = rseq_mempool_range_create(pool);
        if (!pool->range_list)
                goto error_alloc;
@@ -748,15 +995,16 @@ void __rseq_percpu *__rseq_percpu_malloc(struct rseq_mempool *pool,
        /* Get first entry from free list. */
        node = pool->free_list_head;
        if (node != NULL) {
-               uintptr_t ptr = (uintptr_t) node;
-               void *range_base = (void *) (ptr & (~(pool->attr.stride - 1)));
+               void *range_base, *ptr;
 
+               ptr = __rseq_free_list_to_percpu_ptr(pool, node);
+               range_base = (void *) ((uintptr_t) ptr & (~(pool->attr.stride - 1)));
                range = (struct rseq_mempool_range *) (range_base - RANGE_HEADER_OFFSET);
                /* Remove node from free list (update head). */
                pool->free_list_head = node->next;
-               item_offset = (uintptr_t) ((void *) node - range_base);
+               item_offset = (uintptr_t) (ptr - range_base);
                rseq_percpu_check_poison_item(pool, range, item_offset);
-               addr = (void __rseq_percpu *) node;
+               addr = __rseq_free_list_to_percpu_ptr(pool, node);
                goto end;
        }
        /*
@@ -851,11 +1099,11 @@ void librseq_mempool_percpu_free(void __rseq_percpu *_ptr, size_t stride)
        head = pool->free_list_head;
        if (pool->attr.poison_set)
                rseq_percpu_poison_item(pool, range, item_offset);
-       /* Free-list is in CPU 0 range. */
-       item = (struct free_list_node *) ptr;
+       item = __rseq_percpu_to_free_list_ptr(pool, _ptr);
        /*
         * Setting the next pointer will overwrite the first uintptr_t
-        * poison for CPU 0.
+        * poison for either CPU 0 (populate all) or init data (populate
+        * none).
         */
        item->next = head;
        pool->free_list_head = item;
@@ -1068,6 +1316,17 @@ int rseq_mempool_attr_set_poison(struct rseq_mempool_attr *attr,
        return 0;
 }
 
+int rseq_mempool_attr_set_populate_policy(struct rseq_mempool_attr *attr,
+               enum rseq_mempool_populate_policy policy)
+{
+       if (!attr) {
+               errno = EINVAL;
+               return -1;
+       }
+       attr->populate_policy = policy;
+       return 0;
+}
+
 int rseq_mempool_get_max_nr_cpus(struct rseq_mempool *mempool)
 {
        if (!mempool || mempool->attr.type != MEMPOOL_TYPE_PERCPU) {
index 7a8c79b5b8e92fa0fd8a0f2c3ff49129828ac760..81a647292341a67075b5de154ed985463b4ff1e1 100644 (file)
@@ -64,6 +64,9 @@ static void test_mempool_fill(unsigned long max_nr_ranges, size_t stride)
        ok(ret == 0, "Setting mempool max_nr_ranges=%lu", max_nr_ranges);
        ret = rseq_mempool_attr_set_poison(attr, POISON_VALUE);
        ok(ret == 0, "Setting mempool poison");
+       ret = rseq_mempool_attr_set_populate_policy(attr,
+                       RSEQ_MEMPOOL_POPULATE_ALL);
+       ok(ret == 0, "Setting mempool populate policy to ALL");
        mempool = rseq_mempool_create("test_data",
                        sizeof(struct test_data), attr);
        ok(mempool, "Create mempool of size %zu", stride);
@@ -249,6 +252,10 @@ static void run_robust_tests(void)
        ret = rseq_mempool_attr_set_percpu(attr, RSEQ_MEMPOOL_STRIDE, 1);
        ok(ret == 0, "Setting mempool percpu type");
 
+       ret = rseq_mempool_attr_set_populate_policy(attr,
+                       RSEQ_MEMPOOL_POPULATE_ALL);
+       ok(ret == 0, "Setting mempool populate policy to ALL");
+
        pool = rseq_mempool_create("mempool-robust",
                                sizeof(struct test_data), attr);
 
This page took 0.036339 seconds and 4 git commands to generate.