mm: memcontrol: do not uncharge old page in page cache replacement
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index fc10620967c79d7b8fbbcc0f82fd9804b1562f13..bf35bff282fc09412b34929f9829bfe86b88b361 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -66,7 +66,6 @@
 #include "internal.h"
 #include <net/sock.h>
 #include <net/ip.h>
-#include <net/tcp_memcontrol.h>
 #include "slab.h"
 
 #include <asm/uaccess.h>
 struct cgroup_subsys memory_cgrp_subsys __read_mostly;
 EXPORT_SYMBOL(memory_cgrp_subsys);
 
+struct mem_cgroup *root_mem_cgroup __read_mostly;
+
 #define MEM_CGROUP_RECLAIM_RETRIES     5
-static struct mem_cgroup *root_mem_cgroup __read_mostly;
-struct cgroup_subsys_state *mem_cgroup_root_css __read_mostly;
+
+/* Socket memory accounting disabled? */
+static bool cgroup_memory_nosocket;
+
+/* Kernel memory accounting disabled? */
+static bool cgroup_memory_nokmem;
 
 /* Whether the swap controller is active */
 #ifdef CONFIG_MEMCG_SWAP
@@ -87,6 +92,12 @@ int do_swap_account __read_mostly;
 #define do_swap_account                0
 #endif
 
+/* Whether legacy memory+swap accounting is active */
+static bool do_memsw_account(void)
+{
+       return !cgroup_subsys_on_dfl(memory_cgrp_subsys) && do_swap_account;
+}
+
 static const char * const mem_cgroup_stat_names[] = {
        "cache",
        "rss",
@@ -230,6 +241,7 @@ enum res_type {
        _MEMSWAP,
        _OOM_TYPE,
        _KMEM,
+       _TCP,
 };
 
 #define MEMFILE_PRIVATE(x, val)        ((x) << 16 | (val))
@@ -238,13 +250,6 @@ enum res_type {
 /* Used for OOM notifier */
 #define OOM_CONTROL            (0)
 
-/*
- * The memcg_create_mutex will be held whenever a new cgroup is created.
- * As a consequence, any change that needs to protect against new child cgroups
- * appearing has to hold it as well.
- */
-static DEFINE_MUTEX(memcg_create_mutex);
-
 /* Some nice accessors for the vmpressure. */
 struct vmpressure *memcg_to_vmpressure(struct mem_cgroup *memcg)
 {
@@ -288,65 +293,7 @@ static inline struct mem_cgroup *mem_cgroup_from_id(unsigned short id)
        return mem_cgroup_from_css(css);
 }
 
-/* Writing them here to avoid exposing memcg's inner layout */
-#if defined(CONFIG_INET) && defined(CONFIG_MEMCG_KMEM)
-
-void sock_update_memcg(struct sock *sk)
-{
-       if (mem_cgroup_sockets_enabled) {
-               struct mem_cgroup *memcg;
-               struct cg_proto *cg_proto;
-
-               BUG_ON(!sk->sk_prot->proto_cgroup);
-
-               /* Socket cloning can throw us here with sk_cgrp already
-                * filled. It won't however, necessarily happen from
-                * process context. So the test for root memcg given
-                * the current task's memcg won't help us in this case.
-                *
-                * Respecting the original socket's memcg is a better
-                * decision in this case.
-                */
-               if (sk->sk_cgrp) {
-                       BUG_ON(mem_cgroup_is_root(sk->sk_cgrp->memcg));
-                       css_get(&sk->sk_cgrp->memcg->css);
-                       return;
-               }
-
-               rcu_read_lock();
-               memcg = mem_cgroup_from_task(current);
-               cg_proto = sk->sk_prot->proto_cgroup(memcg);
-               if (cg_proto && test_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags) &&
-                   css_tryget_online(&memcg->css)) {
-                       sk->sk_cgrp = cg_proto;
-               }
-               rcu_read_unlock();
-       }
-}
-EXPORT_SYMBOL(sock_update_memcg);
-
-void sock_release_memcg(struct sock *sk)
-{
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct mem_cgroup *memcg;
-               WARN_ON(!sk->sk_cgrp->memcg);
-               memcg = sk->sk_cgrp->memcg;
-               css_put(&sk->sk_cgrp->memcg->css);
-       }
-}
-
-struct cg_proto *tcp_proto_cgroup(struct mem_cgroup *memcg)
-{
-       if (!memcg || mem_cgroup_is_root(memcg))
-               return NULL;
-
-       return &memcg->tcp_mem;
-}
-EXPORT_SYMBOL(tcp_proto_cgroup);
-
-#endif
-
-#ifdef CONFIG_MEMCG_KMEM
+#ifndef CONFIG_SLOB
 /*
  * This will be the memcg's index in each cache's ->memcg_params.memcg_caches.
  * The main reason for not using cgroup id for this:
@@ -395,10 +342,10 @@ void memcg_put_cache_ids(void)
  * conditional to this static branch, we'll have to allow modules that do
  * kmem_cache_alloc and the like to see this symbol as well
  */
-struct static_key memcg_kmem_enabled_key;
+DEFINE_STATIC_KEY_FALSE(memcg_kmem_enabled_key);
 EXPORT_SYMBOL(memcg_kmem_enabled_key);
 
-#endif /* CONFIG_MEMCG_KMEM */
+#endif /* !CONFIG_SLOB */
 
 static struct mem_cgroup_per_zone *
 mem_cgroup_zone_zoneinfo(struct mem_cgroup *memcg, struct zone *zone)
@@ -419,26 +366,16 @@ mem_cgroup_zone_zoneinfo(struct mem_cgroup *memcg, struct zone *zone)
  *
  * If memcg is bound to a traditional hierarchy, the css of root_mem_cgroup
  * is returned.
- *
- * XXX: The above description of behavior on the default hierarchy isn't
- * strictly true yet as replace_page_cache_page() can modify the
- * association before @page is released even on the default hierarchy;
- * however, the current and planned usages don't mix the the two functions
- * and replace_page_cache_page() will soon be updated to make the invariant
- * actually true.
  */
 struct cgroup_subsys_state *mem_cgroup_css_from_page(struct page *page)
 {
        struct mem_cgroup *memcg;
 
-       rcu_read_lock();
-
        memcg = page->mem_cgroup;
 
        if (!memcg || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
                memcg = root_mem_cgroup;
 
-       rcu_read_unlock();
        return &memcg->css;
 }
 
@@ -696,7 +633,7 @@ static unsigned long mem_cgroup_read_events(struct mem_cgroup *memcg,
 
 static void mem_cgroup_charge_statistics(struct mem_cgroup *memcg,
                                         struct page *page,
-                                        int nr_pages)
+                                        bool compound, int nr_pages)
 {
        /*
         * Here, RSS means 'mapped anon' and anon's SwapCache. Shmem/tmpfs is
@@ -709,9 +646,11 @@ static void mem_cgroup_charge_statistics(struct mem_cgroup *memcg,
                __this_cpu_add(memcg->stat->count[MEM_CGROUP_STAT_CACHE],
                                nr_pages);
 
-       if (PageTransHuge(page))
+       if (compound) {
+               VM_BUG_ON_PAGE(!PageTransHuge(page), page);
                __this_cpu_add(memcg->stat->count[MEM_CGROUP_STAT_RSS_HUGE],
                                nr_pages);
+       }
 
        /* pagein of a big page is an event. So, ignore page size */
        if (nr_pages > 0)
@@ -946,17 +885,8 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
                if (css == &root->css)
                        break;
 
-               if (css_tryget(css)) {
-                       /*
-                        * Make sure the memcg is initialized:
-                        * mem_cgroup_css_online() orders the the
-                        * initialization against setting the flag.
-                        */
-                       if (smp_load_acquire(&memcg->initialized))
-                               break;
-
-                       css_put(css);
-               }
+               if (css_tryget(css))
+                       break;
 
                memcg = NULL;
        }
@@ -1162,9 +1092,6 @@ bool task_in_mem_cgroup(struct task_struct *task, struct mem_cgroup *memcg)
        return ret;
 }
 
-#define mem_cgroup_from_counter(counter, member)       \
-       container_of(counter, struct mem_cgroup, member)
-
 /**
  * mem_cgroup_margin - calculate chargeable space of a memory cgroup
  * @memcg: the memory cgroup
@@ -1183,7 +1110,7 @@ static unsigned long mem_cgroup_margin(struct mem_cgroup *memcg)
        if (count < limit)
                margin = limit - count;
 
-       if (do_swap_account) {
+       if (do_memsw_account()) {
                count = page_counter_read(&memcg->memsw);
                limit = READ_ONCE(memcg->memsw.limit);
                if (count <= limit)
@@ -1325,9 +1252,12 @@ static unsigned long mem_cgroup_get_limit(struct mem_cgroup *memcg)
        limit = memcg->memory.limit;
        if (mem_cgroup_swappiness(memcg)) {
                unsigned long memsw_limit;
+               unsigned long swap_limit;
 
                memsw_limit = memcg->memsw.limit;
-               limit = min(limit + total_swap_pages, memsw_limit);
+               swap_limit = memcg->swap.limit;
+               swap_limit = min(swap_limit, (unsigned long)total_swap_pages);
+               limit = min(limit + swap_limit, memsw_limit);
        }
        return limit;
 }
@@ -1909,7 +1839,7 @@ static void drain_stock(struct memcg_stock_pcp *stock)
 
        if (stock->nr_pages) {
                page_counter_uncharge(&old->memory, stock->nr_pages);
-               if (do_swap_account)
+               if (do_memsw_account())
                        page_counter_uncharge(&old->memsw, stock->nr_pages);
                css_put_many(&old->css, stock->nr_pages);
                stock->nr_pages = 0;
@@ -1997,6 +1927,26 @@ static int memcg_cpu_hotplug_callback(struct notifier_block *nb,
        return NOTIFY_OK;
 }
 
+static void reclaim_high(struct mem_cgroup *memcg,
+                        unsigned int nr_pages,
+                        gfp_t gfp_mask)
+{
+       do {
+               if (page_counter_read(&memcg->memory) <= memcg->high)
+                       continue;
+               mem_cgroup_events(memcg, MEMCG_HIGH, 1);
+               try_to_free_mem_cgroup_pages(memcg, nr_pages, gfp_mask, true);
+       } while ((memcg = parent_mem_cgroup(memcg)));
+}
+
+static void high_work_func(struct work_struct *work)
+{
+       struct mem_cgroup *memcg;
+
+       memcg = container_of(work, struct mem_cgroup, high_work);
+       reclaim_high(memcg, CHARGE_BATCH, GFP_KERNEL);
+}
+
 /*
  * Scheduled by try_charge() to be executed from the userland return path
  * and reclaims memory over the high limit.
@@ -2004,20 +1954,13 @@ static int memcg_cpu_hotplug_callback(struct notifier_block *nb,
 void mem_cgroup_handle_over_high(void)
 {
        unsigned int nr_pages = current->memcg_nr_pages_over_high;
-       struct mem_cgroup *memcg, *pos;
+       struct mem_cgroup *memcg;
 
        if (likely(!nr_pages))
                return;
 
-       pos = memcg = get_mem_cgroup_from_mm(current->mm);
-
-       do {
-               if (page_counter_read(&pos->memory) <= pos->high)
-                       continue;
-               mem_cgroup_events(pos, MEMCG_HIGH, 1);
-               try_to_free_mem_cgroup_pages(pos, nr_pages, GFP_KERNEL, true);
-       } while ((pos = parent_mem_cgroup(pos)));
-
+       memcg = get_mem_cgroup_from_mm(current->mm);
+       reclaim_high(memcg, nr_pages, GFP_KERNEL);
        css_put(&memcg->css);
        current->memcg_nr_pages_over_high = 0;
 }
@@ -2039,11 +1982,11 @@ retry:
        if (consume_stock(memcg, nr_pages))
                return 0;
 
-       if (!do_swap_account ||
+       if (!do_memsw_account() ||
            page_counter_try_charge(&memcg->memsw, batch, &counter)) {
                if (page_counter_try_charge(&memcg->memory, batch, &counter))
                        goto done_restock;
-               if (do_swap_account)
+               if (do_memsw_account())
                        page_counter_uncharge(&memcg->memsw, batch);
                mem_over_limit = mem_cgroup_from_counter(counter, memory);
        } else {
@@ -2130,7 +2073,7 @@ force:
         * temporarily by force charging it.
         */
        page_counter_charge(&memcg->memory, nr_pages);
-       if (do_swap_account)
+       if (do_memsw_account())
                page_counter_charge(&memcg->memsw, nr_pages);
        css_get_many(&memcg->css, nr_pages);
 
@@ -2152,6 +2095,11 @@ done_restock:
         */
        do {
                if (page_counter_read(&memcg->memory) > memcg->high) {
+                       /* Don't bother a random interrupted task */
+                       if (in_interrupt()) {
+                               schedule_work(&memcg->high_work);
+                               break;
+                       }
                        current->memcg_nr_pages_over_high += batch;
                        set_notify_resume(current);
                        break;
@@ -2167,7 +2115,7 @@ static void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
                return;
 
        page_counter_uncharge(&memcg->memory, nr_pages);
-       if (do_swap_account)
+       if (do_memsw_account())
                page_counter_uncharge(&memcg->memsw, nr_pages);
 
        css_put_many(&memcg->css, nr_pages);
@@ -2238,7 +2186,7 @@ static void commit_charge(struct page *page, struct mem_cgroup *memcg,
                unlock_page_lru(page, isolated);
 }
 
-#ifdef CONFIG_MEMCG_KMEM
+#ifndef CONFIG_SLOB
 static int memcg_alloc_cache_id(void)
 {
        int id, size;
@@ -2356,7 +2304,7 @@ static void memcg_schedule_kmem_cache_create(struct mem_cgroup *memcg,
  * Can't be called in interrupt context or from kernel threads.
  * This function needs to be called with rcu_read_lock() held.
  */
-struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep)
+struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep, gfp_t gfp)
 {
        struct mem_cgroup *memcg;
        struct kmem_cache *memcg_cachep;
@@ -2364,6 +2312,12 @@ struct kmem_cache *__memcg_kmem_get_cache(struct kmem_cache *cachep)
 
        VM_BUG_ON(!is_root_cache(cachep));
 
+       if (cachep->flags & SLAB_ACCOUNT)
+               gfp |= __GFP_ACCOUNT;
+
+       if (!(gfp & __GFP_ACCOUNT))
+               return cachep;
+
        if (current->memcg_kmem_skip_account)
                return cachep;
 
@@ -2407,16 +2361,17 @@ int __memcg_kmem_charge_memcg(struct page *page, gfp_t gfp, int order,
        struct page_counter *counter;
        int ret;
 
-       if (!memcg_kmem_is_active(memcg))
+       if (!memcg_kmem_online(memcg))
                return 0;
 
-       if (!page_counter_try_charge(&memcg->kmem, nr_pages, &counter))
-               return -ENOMEM;
-
        ret = try_charge(memcg, gfp, nr_pages);
-       if (ret) {
-               page_counter_uncharge(&memcg->kmem, nr_pages);
+       if (ret)
                return ret;
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) &&
+           !page_counter_try_charge(&memcg->kmem, nr_pages, &counter)) {
+               cancel_charge(memcg, nr_pages);
+               return -ENOMEM;
        }
 
        page->mem_cgroup = memcg;
@@ -2445,23 +2400,23 @@ void __memcg_kmem_uncharge(struct page *page, int order)
 
        VM_BUG_ON_PAGE(mem_cgroup_is_root(memcg), page);
 
-       page_counter_uncharge(&memcg->kmem, nr_pages);
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               page_counter_uncharge(&memcg->kmem, nr_pages);
+
        page_counter_uncharge(&memcg->memory, nr_pages);
-       if (do_swap_account)
+       if (do_memsw_account())
                page_counter_uncharge(&memcg->memsw, nr_pages);
 
        page->mem_cgroup = NULL;
        css_put_many(&memcg->css, nr_pages);
 }
-#endif /* CONFIG_MEMCG_KMEM */
+#endif /* !CONFIG_SLOB */
 
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
 
 /*
  * Because tail pages are not marked as "used", set it. We're under
- * zone->lru_lock, 'splitting on pmd' and compound_lock.
- * charge/uncharge will be never happen and move_account() is done under
- * compound_lock(), so we don't have to take care of races.
+ * zone->lru_lock and migration entries setup in all page mappings.
  */
 void mem_cgroup_split_huge_fixup(struct page *head)
 {
@@ -2715,14 +2670,6 @@ static inline bool memcg_has_children(struct mem_cgroup *memcg)
 {
        bool ret;
 
-       /*
-        * The lock does not prevent addition or deletion of children, but
-        * it prevents a new child from being initialized based on this
-        * parent in css_online(), so it's enough to decide whether
-        * hierarchically inherited attributes can still be changed or not.
-        */
-       lockdep_assert_held(&memcg_create_mutex);
-
        rcu_read_lock();
        ret = css_next_child(NULL, &memcg->css);
        rcu_read_unlock();
@@ -2785,10 +2732,8 @@ static int mem_cgroup_hierarchy_write(struct cgroup_subsys_state *css,
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
        struct mem_cgroup *parent_memcg = mem_cgroup_from_css(memcg->css.parent);
 
-       mutex_lock(&memcg_create_mutex);
-
        if (memcg->use_hierarchy == val)
-               goto out;
+               return 0;
 
        /*
         * If parent's use_hierarchy is set, we can't make any modifications
@@ -2807,9 +2752,6 @@ static int mem_cgroup_hierarchy_write(struct cgroup_subsys_state *css,
        } else
                retval = -EINVAL;
 
-out:
-       mutex_unlock(&memcg_create_mutex);
-
        return retval;
 }
 
@@ -2867,6 +2809,9 @@ static u64 mem_cgroup_read_u64(struct cgroup_subsys_state *css,
        case _KMEM:
                counter = &memcg->kmem;
                break;
+       case _TCP:
+               counter = &memcg->tcpmem;
+               break;
        default:
                BUG();
        }
@@ -2891,103 +2836,180 @@ static u64 mem_cgroup_read_u64(struct cgroup_subsys_state *css,
        }
 }
 
-#ifdef CONFIG_MEMCG_KMEM
-static int memcg_activate_kmem(struct mem_cgroup *memcg,
-                              unsigned long nr_pages)
+#ifndef CONFIG_SLOB
+static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
-       int err = 0;
        int memcg_id;
 
        BUG_ON(memcg->kmemcg_id >= 0);
-       BUG_ON(memcg->kmem_acct_activated);
-       BUG_ON(memcg->kmem_acct_active);
-
-       /*
-        * For simplicity, we won't allow this to be disabled.  It also can't
-        * be changed if the cgroup has children already, or if tasks had
-        * already joined.
-        *
-        * If tasks join before we set the limit, a person looking at
-        * kmem.usage_in_bytes will have no way to determine when it took
-        * place, which makes the value quite meaningless.
-        *
-        * After it first became limited, changes in the value of the limit are
-        * of course permitted.
-        */
-       mutex_lock(&memcg_create_mutex);
-       if (cgroup_is_populated(memcg->css.cgroup) ||
-           (memcg->use_hierarchy && memcg_has_children(memcg)))
-               err = -EBUSY;
-       mutex_unlock(&memcg_create_mutex);
-       if (err)
-               goto out;
+       BUG_ON(memcg->kmem_state);
 
        memcg_id = memcg_alloc_cache_id();
-       if (memcg_id < 0) {
-               err = memcg_id;
-               goto out;
-       }
+       if (memcg_id < 0)
+               return memcg_id;
 
+       static_branch_inc(&memcg_kmem_enabled_key);
        /*
-        * We couldn't have accounted to this cgroup, because it hasn't got
-        * activated yet, so this should succeed.
-        */
-       err = page_counter_limit(&memcg->kmem, nr_pages);
-       VM_BUG_ON(err);
-
-       static_key_slow_inc(&memcg_kmem_enabled_key);
-       /*
-        * A memory cgroup is considered kmem-active as soon as it gets
+        * A memory cgroup is considered kmem-online as soon as it gets
         * kmemcg_id. Setting the id after enabling static branching will
         * guarantee no one starts accounting before all call sites are
         * patched.
         */
        memcg->kmemcg_id = memcg_id;
-       memcg->kmem_acct_activated = true;
-       memcg->kmem_acct_active = true;
-out:
-       return err;
+       memcg->kmem_state = KMEM_ONLINE;
+
+       return 0;
 }
 
-static int memcg_update_kmem_limit(struct mem_cgroup *memcg,
-                                  unsigned long limit)
+static int memcg_propagate_kmem(struct mem_cgroup *parent,
+                               struct mem_cgroup *memcg)
 {
-       int ret;
+       int ret = 0;
 
        mutex_lock(&memcg_limit_mutex);
-       if (!memcg_kmem_is_active(memcg))
-               ret = memcg_activate_kmem(memcg, limit);
-       else
-               ret = page_counter_limit(&memcg->kmem, limit);
+       /*
+        * If the parent cgroup is not kmem-online now, it cannot be
+        * onlined after this point, because it has at least one child
+        * already.
+        */
+       if (memcg_kmem_online(parent) ||
+           (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nokmem))
+               ret = memcg_online_kmem(memcg);
        mutex_unlock(&memcg_limit_mutex);
        return ret;
 }
 
-static int memcg_propagate_kmem(struct mem_cgroup *memcg)
+static void memcg_offline_kmem(struct mem_cgroup *memcg)
 {
-       int ret = 0;
-       struct mem_cgroup *parent = parent_mem_cgroup(memcg);
+       struct cgroup_subsys_state *css;
+       struct mem_cgroup *parent, *child;
+       int kmemcg_id;
+
+       if (memcg->kmem_state != KMEM_ONLINE)
+               return;
+       /*
+        * Clear the online state before clearing memcg_caches array
+        * entries. The slab_mutex in memcg_deactivate_kmem_caches()
+        * guarantees that no cache will be created for this cgroup
+        * after we are done (see memcg_create_kmem_cache()).
+        */
+       memcg->kmem_state = KMEM_ALLOCATED;
+
+       memcg_deactivate_kmem_caches(memcg);
+
+       kmemcg_id = memcg->kmemcg_id;
+       BUG_ON(kmemcg_id < 0);
 
+       parent = parent_mem_cgroup(memcg);
        if (!parent)
-               return 0;
+               parent = root_mem_cgroup;
 
-       mutex_lock(&memcg_limit_mutex);
        /*
-        * If the parent cgroup is not kmem-active now, it cannot be activated
-        * after this point, because it has at least one child already.
+        * Change kmemcg_id of this cgroup and all its descendants to the
+        * parent's id, and then move all entries from this cgroup's list_lrus
+        * to ones of the parent. After we have finished, all list_lrus
+        * corresponding to this cgroup are guaranteed to remain empty. The
+        * ordering is imposed by list_lru_node->lock taken by
+        * memcg_drain_all_list_lrus().
         */
-       if (memcg_kmem_is_active(parent))
-               ret = memcg_activate_kmem(memcg, PAGE_COUNTER_MAX);
-       mutex_unlock(&memcg_limit_mutex);
-       return ret;
+       css_for_each_descendant_pre(css, &memcg->css) {
+               child = mem_cgroup_from_css(css);
+               BUG_ON(child->kmemcg_id != kmemcg_id);
+               child->kmemcg_id = parent->kmemcg_id;
+               if (!memcg->use_hierarchy)
+                       break;
+       }
+       memcg_drain_all_list_lrus(kmemcg_id, parent->kmemcg_id);
+
+       memcg_free_cache_id(kmemcg_id);
+}
+
+static void memcg_free_kmem(struct mem_cgroup *memcg)
+{
+       /* css_alloc() failed, offlining didn't happen */
+       if (unlikely(memcg->kmem_state == KMEM_ONLINE))
+               memcg_offline_kmem(memcg);
+
+       if (memcg->kmem_state == KMEM_ALLOCATED) {
+               memcg_destroy_kmem_caches(memcg);
+               static_branch_dec(&memcg_kmem_enabled_key);
+               WARN_ON(page_counter_read(&memcg->kmem));
+       }
 }
 #else
+static int memcg_propagate_kmem(struct mem_cgroup *parent, struct mem_cgroup *memcg)
+{
+       return 0;
+}
+static int memcg_online_kmem(struct mem_cgroup *memcg)
+{
+       return 0;
+}
+static void memcg_offline_kmem(struct mem_cgroup *memcg)
+{
+}
+static void memcg_free_kmem(struct mem_cgroup *memcg)
+{
+}
+#endif /* !CONFIG_SLOB */
+
 static int memcg_update_kmem_limit(struct mem_cgroup *memcg,
                                   unsigned long limit)
 {
-       return -EINVAL;
+       int ret = 0;
+
+       mutex_lock(&memcg_limit_mutex);
+       /* Top-level cgroup doesn't propagate from root */
+       if (!memcg_kmem_online(memcg)) {
+               if (cgroup_is_populated(memcg->css.cgroup) ||
+                   (memcg->use_hierarchy && memcg_has_children(memcg)))
+                       ret = -EBUSY;
+               if (ret)
+                       goto out;
+               ret = memcg_online_kmem(memcg);
+               if (ret)
+                       goto out;
+       }
+       ret = page_counter_limit(&memcg->kmem, limit);
+out:
+       mutex_unlock(&memcg_limit_mutex);
+       return ret;
+}
+
+static int memcg_update_tcp_limit(struct mem_cgroup *memcg, unsigned long limit)
+{
+       int ret;
+
+       mutex_lock(&memcg_limit_mutex);
+
+       ret = page_counter_limit(&memcg->tcpmem, limit);
+       if (ret)
+               goto out;
+
+       if (!memcg->tcpmem_active) {
+               /*
+                * The active flag needs to be written after the static_key
+                * update. This is what guarantees that the socket activation
+                * function is the last one to run. See sock_update_memcg() for
+                * details, and note that we don't mark any socket as belonging
+                * to this memcg until that flag is up.
+                *
+                * We need to do this, because static_keys will span multiple
+                * sites, but we can't control their order. If we mark a socket
+                * as accounted, but the accounting functions are not patched in
+                * yet, we'll lose accounting.
+                *
+                * We never race with the readers in sock_update_memcg(),
+                * because when this value changes, the code to process it is not
+                * patched in yet.
+                */
+               static_branch_inc(&memcg_sockets_enabled_key);
+               memcg->tcpmem_active = true;
+       }
+out:
+       mutex_unlock(&memcg_limit_mutex);
+       return ret;
 }
-#endif /* CONFIG_MEMCG_KMEM */
 
 /*
  * The user of this function is...
@@ -3021,6 +3043,9 @@ static ssize_t mem_cgroup_write(struct kernfs_open_file *of,
                case _KMEM:
                        ret = memcg_update_kmem_limit(memcg, nr_pages);
                        break;
+               case _TCP:
+                       ret = memcg_update_tcp_limit(memcg, nr_pages);
+                       break;
                }
                break;
        case RES_SOFT_LIMIT:
@@ -3047,6 +3072,9 @@ static ssize_t mem_cgroup_reset(struct kernfs_open_file *of, char *buf,
        case _KMEM:
                counter = &memcg->kmem;
                break;
+       case _TCP:
+               counter = &memcg->tcpmem;
+               break;
        default:
                BUG();
        }
@@ -3162,7 +3190,7 @@ static int memcg_stat_show(struct seq_file *m, void *v)
        BUILD_BUG_ON(ARRAY_SIZE(mem_cgroup_lru_names) != NR_LRU_LISTS);
 
        for (i = 0; i < MEM_CGROUP_STAT_NSTATS; i++) {
-               if (i == MEM_CGROUP_STAT_SWAP && !do_swap_account)
+               if (i == MEM_CGROUP_STAT_SWAP && !do_memsw_account())
                        continue;
                seq_printf(m, "%s %lu\n", mem_cgroup_stat_names[i],
                           mem_cgroup_read_stat(memcg, i) * PAGE_SIZE);
@@ -3184,14 +3212,14 @@ static int memcg_stat_show(struct seq_file *m, void *v)
        }
        seq_printf(m, "hierarchical_memory_limit %llu\n",
                   (u64)memory * PAGE_SIZE);
-       if (do_swap_account)
+       if (do_memsw_account())
                seq_printf(m, "hierarchical_memsw_limit %llu\n",
                           (u64)memsw * PAGE_SIZE);
 
        for (i = 0; i < MEM_CGROUP_STAT_NSTATS; i++) {
                unsigned long long val = 0;
 
-               if (i == MEM_CGROUP_STAT_SWAP && !do_swap_account)
+               if (i == MEM_CGROUP_STAT_SWAP && !do_memsw_account())
                        continue;
                for_each_mem_cgroup_tree(mi, memcg)
                        val += mem_cgroup_read_stat(mi, i) * PAGE_SIZE;
@@ -3322,7 +3350,7 @@ static void mem_cgroup_threshold(struct mem_cgroup *memcg)
 {
        while (memcg) {
                __mem_cgroup_threshold(memcg, false);
-               if (do_swap_account)
+               if (do_memsw_account())
                        __mem_cgroup_threshold(memcg, true);
 
                memcg = parent_mem_cgroup(memcg);
@@ -3522,16 +3550,17 @@ static void __mem_cgroup_usage_unregister_event(struct mem_cgroup *memcg,
 swap_buffers:
        /* Swap primary and spare array */
        thresholds->spare = thresholds->primary;
-       /* If all events are unregistered, free the spare array */
-       if (!new) {
-               kfree(thresholds->spare);
-               thresholds->spare = NULL;
-       }
 
        rcu_assign_pointer(thresholds->primary, new);
 
        /* To be sure that nobody uses thresholds */
        synchronize_rcu();
+
+       /* If all events are unregistered, free the spare array */
+       if (!new) {
+               kfree(thresholds->spare);
+               thresholds->spare = NULL;
+       }
 unlock:
        mutex_unlock(&memcg->thresholds_lock);
 }
@@ -3612,119 +3641,37 @@ static int mem_cgroup_oom_control_write(struct cgroup_subsys_state *css,
        return 0;
 }
 
-#ifdef CONFIG_MEMCG_KMEM
-static int memcg_init_kmem(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
-{
-       int ret;
-
-       ret = memcg_propagate_kmem(memcg);
-       if (ret)
-               return ret;
+#ifdef CONFIG_CGROUP_WRITEBACK
 
-       return mem_cgroup_sockets_init(memcg, ss);
+struct list_head *mem_cgroup_cgwb_list(struct mem_cgroup *memcg)
+{
+       return &memcg->cgwb_list;
 }
 
-static void memcg_deactivate_kmem(struct mem_cgroup *memcg)
+static int memcg_wb_domain_init(struct mem_cgroup *memcg, gfp_t gfp)
 {
-       struct cgroup_subsys_state *css;
-       struct mem_cgroup *parent, *child;
-       int kmemcg_id;
+       return wb_domain_init(&memcg->cgwb_domain, gfp);
+}
 
-       if (!memcg->kmem_acct_active)
-               return;
+static void memcg_wb_domain_exit(struct mem_cgroup *memcg)
+{
+       wb_domain_exit(&memcg->cgwb_domain);
+}
 
-       /*
-        * Clear the 'active' flag before clearing memcg_caches arrays entries.
-        * Since we take the slab_mutex in memcg_deactivate_kmem_caches(), it
-        * guarantees no cache will be created for this cgroup after we are
-        * done (see memcg_create_kmem_cache()).
-        */
-       memcg->kmem_acct_active = false;
+static void memcg_wb_domain_size_changed(struct mem_cgroup *memcg)
+{
+       wb_domain_size_changed(&memcg->cgwb_domain);
+}
 
-       memcg_deactivate_kmem_caches(memcg);
+struct wb_domain *mem_cgroup_wb_domain(struct bdi_writeback *wb)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(wb->memcg_css);
 
-       kmemcg_id = memcg->kmemcg_id;
-       BUG_ON(kmemcg_id < 0);
+       if (!memcg->css.parent)
+               return NULL;
 
-       parent = parent_mem_cgroup(memcg);
-       if (!parent)
-               parent = root_mem_cgroup;
-
-       /*
-        * Change kmemcg_id of this cgroup and all its descendants to the
-        * parent's id, and then move all entries from this cgroup's list_lrus
-        * to ones of the parent. After we have finished, all list_lrus
-        * corresponding to this cgroup are guaranteed to remain empty. The
-        * ordering is imposed by list_lru_node->lock taken by
-        * memcg_drain_all_list_lrus().
-        */
-       css_for_each_descendant_pre(css, &memcg->css) {
-               child = mem_cgroup_from_css(css);
-               BUG_ON(child->kmemcg_id != kmemcg_id);
-               child->kmemcg_id = parent->kmemcg_id;
-               if (!memcg->use_hierarchy)
-                       break;
-       }
-       memcg_drain_all_list_lrus(kmemcg_id, parent->kmemcg_id);
-
-       memcg_free_cache_id(kmemcg_id);
-}
-
-static void memcg_destroy_kmem(struct mem_cgroup *memcg)
-{
-       if (memcg->kmem_acct_activated) {
-               memcg_destroy_kmem_caches(memcg);
-               static_key_slow_dec(&memcg_kmem_enabled_key);
-               WARN_ON(page_counter_read(&memcg->kmem));
-       }
-       mem_cgroup_sockets_destroy(memcg);
-}
-#else
-static int memcg_init_kmem(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
-{
-       return 0;
-}
-
-static void memcg_deactivate_kmem(struct mem_cgroup *memcg)
-{
-}
-
-static void memcg_destroy_kmem(struct mem_cgroup *memcg)
-{
-}
-#endif
-
-#ifdef CONFIG_CGROUP_WRITEBACK
-
-struct list_head *mem_cgroup_cgwb_list(struct mem_cgroup *memcg)
-{
-       return &memcg->cgwb_list;
-}
-
-static int memcg_wb_domain_init(struct mem_cgroup *memcg, gfp_t gfp)
-{
-       return wb_domain_init(&memcg->cgwb_domain, gfp);
-}
-
-static void memcg_wb_domain_exit(struct mem_cgroup *memcg)
-{
-       wb_domain_exit(&memcg->cgwb_domain);
-}
-
-static void memcg_wb_domain_size_changed(struct mem_cgroup *memcg)
-{
-       wb_domain_size_changed(&memcg->cgwb_domain);
-}
-
-struct wb_domain *mem_cgroup_wb_domain(struct bdi_writeback *wb)
-{
-       struct mem_cgroup *memcg = mem_cgroup_from_css(wb->memcg_css);
-
-       if (!memcg->css.parent)
-               return NULL;
-
-       return &memcg->cgwb_domain;
-}
+       return &memcg->cgwb_domain;
+}
 
 /**
  * mem_cgroup_wb_stats - retrieve writeback related stats from its memcg
@@ -4081,7 +4028,6 @@ static struct cftype mem_cgroup_legacy_files[] = {
                .seq_show = memcg_numa_stat_show,
        },
 #endif
-#ifdef CONFIG_MEMCG_KMEM
        {
                .name = "kmem.limit_in_bytes",
                .private = MEMFILE_PRIVATE(_KMEM, RES_LIMIT),
@@ -4114,7 +4060,29 @@ static struct cftype mem_cgroup_legacy_files[] = {
                .seq_show = memcg_slab_show,
        },
 #endif
-#endif
+       {
+               .name = "kmem.tcp.limit_in_bytes",
+               .private = MEMFILE_PRIVATE(_TCP, RES_LIMIT),
+               .write = mem_cgroup_write,
+               .read_u64 = mem_cgroup_read_u64,
+       },
+       {
+               .name = "kmem.tcp.usage_in_bytes",
+               .private = MEMFILE_PRIVATE(_TCP, RES_USAGE),
+               .read_u64 = mem_cgroup_read_u64,
+       },
+       {
+               .name = "kmem.tcp.failcnt",
+               .private = MEMFILE_PRIVATE(_TCP, RES_FAILCNT),
+               .write = mem_cgroup_reset,
+               .read_u64 = mem_cgroup_read_u64,
+       },
+       {
+               .name = "kmem.tcp.max_usage_in_bytes",
+               .private = MEMFILE_PRIVATE(_TCP, RES_MAX_USAGE),
+               .write = mem_cgroup_reset,
+               .read_u64 = mem_cgroup_read_u64,
+       },
        { },    /* terminate */
 };
 
@@ -4153,153 +4121,92 @@ static void free_mem_cgroup_per_zone_info(struct mem_cgroup *memcg, int node)
        kfree(memcg->nodeinfo[node]);
 }
 
-static struct mem_cgroup *mem_cgroup_alloc(void)
-{
-       struct mem_cgroup *memcg;
-       size_t size;
-
-       size = sizeof(struct mem_cgroup);
-       size += nr_node_ids * sizeof(struct mem_cgroup_per_node *);
-
-       memcg = kzalloc(size, GFP_KERNEL);
-       if (!memcg)
-               return NULL;
-
-       memcg->stat = alloc_percpu(struct mem_cgroup_stat_cpu);
-       if (!memcg->stat)
-               goto out_free;
-
-       if (memcg_wb_domain_init(memcg, GFP_KERNEL))
-               goto out_free_stat;
-
-       return memcg;
-
-out_free_stat:
-       free_percpu(memcg->stat);
-out_free:
-       kfree(memcg);
-       return NULL;
-}
-
-/*
- * At destroying mem_cgroup, references from swap_cgroup can remain.
- * (scanning all at force_empty is too costly...)
- *
- * Instead of clearing all references at force_empty, we remember
- * the number of reference from swap_cgroup and free mem_cgroup when
- * it goes down to 0.
- *
- * Removal of cgroup itself succeeds regardless of refs from swap.
- */
-
-static void __mem_cgroup_free(struct mem_cgroup *memcg)
+static void mem_cgroup_free(struct mem_cgroup *memcg)
 {
        int node;
 
-       mem_cgroup_remove_from_trees(memcg);
-
+       memcg_wb_domain_exit(memcg);
        for_each_node(node)
                free_mem_cgroup_per_zone_info(memcg, node);
-
        free_percpu(memcg->stat);
-       memcg_wb_domain_exit(memcg);
        kfree(memcg);
 }
 
-/*
- * Returns the parent mem_cgroup in memcgroup hierarchy with hierarchy enabled.
- */
-struct mem_cgroup *parent_mem_cgroup(struct mem_cgroup *memcg)
-{
-       if (!memcg->memory.parent)
-               return NULL;
-       return mem_cgroup_from_counter(memcg->memory.parent, memory);
-}
-EXPORT_SYMBOL(parent_mem_cgroup);
-
-static struct cgroup_subsys_state * __ref
-mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
+static struct mem_cgroup *mem_cgroup_alloc(void)
 {
        struct mem_cgroup *memcg;
-       long error = -ENOMEM;
+       size_t size;
        int node;
 
-       memcg = mem_cgroup_alloc();
+       size = sizeof(struct mem_cgroup);
+       size += nr_node_ids * sizeof(struct mem_cgroup_per_node *);
+
+       memcg = kzalloc(size, GFP_KERNEL);
        if (!memcg)
-               return ERR_PTR(error);
+               return NULL;
+
+       memcg->stat = alloc_percpu(struct mem_cgroup_stat_cpu);
+       if (!memcg->stat)
+               goto fail;
 
        for_each_node(node)
                if (alloc_mem_cgroup_per_zone_info(memcg, node))
-                       goto free_out;
+                       goto fail;
 
-       /* root ? */
-       if (parent_css == NULL) {
-               root_mem_cgroup = memcg;
-               mem_cgroup_root_css = &memcg->css;
-               page_counter_init(&memcg->memory, NULL);
-               memcg->high = PAGE_COUNTER_MAX;
-               memcg->soft_limit = PAGE_COUNTER_MAX;
-               page_counter_init(&memcg->memsw, NULL);
-               page_counter_init(&memcg->kmem, NULL);
-       }
+       if (memcg_wb_domain_init(memcg, GFP_KERNEL))
+               goto fail;
 
+       INIT_WORK(&memcg->high_work, high_work_func);
        memcg->last_scanned_node = MAX_NUMNODES;
        INIT_LIST_HEAD(&memcg->oom_notify);
-       memcg->move_charge_at_immigrate = 0;
        mutex_init(&memcg->thresholds_lock);
        spin_lock_init(&memcg->move_lock);
        vmpressure_init(&memcg->vmpressure);
        INIT_LIST_HEAD(&memcg->event_list);
        spin_lock_init(&memcg->event_list_lock);
-#ifdef CONFIG_MEMCG_KMEM
+       memcg->socket_pressure = jiffies;
+#ifndef CONFIG_SLOB
        memcg->kmemcg_id = -1;
 #endif
 #ifdef CONFIG_CGROUP_WRITEBACK
        INIT_LIST_HEAD(&memcg->cgwb_list);
 #endif
-       return &memcg->css;
-
-free_out:
-       __mem_cgroup_free(memcg);
-       return ERR_PTR(error);
+       return memcg;
+fail:
+       mem_cgroup_free(memcg);
+       return NULL;
 }
 
-static int
-mem_cgroup_css_online(struct cgroup_subsys_state *css)
+static struct cgroup_subsys_state * __ref
+mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
 {
-       struct mem_cgroup *memcg = mem_cgroup_from_css(css);
-       struct mem_cgroup *parent = mem_cgroup_from_css(css->parent);
-       int ret;
-
-       if (css->id > MEM_CGROUP_ID_MAX)
-               return -ENOSPC;
-
-       if (!parent)
-               return 0;
-
-       mutex_lock(&memcg_create_mutex);
+       struct mem_cgroup *parent = mem_cgroup_from_css(parent_css);
+       struct mem_cgroup *memcg;
+       long error = -ENOMEM;
 
-       memcg->use_hierarchy = parent->use_hierarchy;
-       memcg->oom_kill_disable = parent->oom_kill_disable;
-       memcg->swappiness = mem_cgroup_swappiness(parent);
+       memcg = mem_cgroup_alloc();
+       if (!memcg)
+               return ERR_PTR(error);
 
-       if (parent->use_hierarchy) {
+       memcg->high = PAGE_COUNTER_MAX;
+       memcg->soft_limit = PAGE_COUNTER_MAX;
+       if (parent) {
+               memcg->swappiness = mem_cgroup_swappiness(parent);
+               memcg->oom_kill_disable = parent->oom_kill_disable;
+       }
+       if (parent && parent->use_hierarchy) {
+               memcg->use_hierarchy = true;
                page_counter_init(&memcg->memory, &parent->memory);
-               memcg->high = PAGE_COUNTER_MAX;
-               memcg->soft_limit = PAGE_COUNTER_MAX;
+               page_counter_init(&memcg->swap, &parent->swap);
                page_counter_init(&memcg->memsw, &parent->memsw);
                page_counter_init(&memcg->kmem, &parent->kmem);
-
-               /*
-                * No need to take a reference to the parent because cgroup
-                * core guarantees its existence.
-                */
+               page_counter_init(&memcg->tcpmem, &parent->tcpmem);
        } else {
                page_counter_init(&memcg->memory, NULL);
-               memcg->high = PAGE_COUNTER_MAX;
-               memcg->soft_limit = PAGE_COUNTER_MAX;
+               page_counter_init(&memcg->swap, NULL);
                page_counter_init(&memcg->memsw, NULL);
                page_counter_init(&memcg->kmem, NULL);
+               page_counter_init(&memcg->tcpmem, NULL);
                /*
                 * Deeper hierarchy with use_hierarchy == false doesn't make
                 * much sense so let cgroup subsystem know about this
@@ -4308,18 +4215,31 @@ mem_cgroup_css_online(struct cgroup_subsys_state *css)
                if (parent != root_mem_cgroup)
                        memory_cgrp_subsys.broken_hierarchy = true;
        }
-       mutex_unlock(&memcg_create_mutex);
 
-       ret = memcg_init_kmem(memcg, &memory_cgrp_subsys);
-       if (ret)
-               return ret;
+       /* The following stuff does not apply to the root */
+       if (!parent) {
+               root_mem_cgroup = memcg;
+               return &memcg->css;
+       }
 
-       /*
-        * Make sure the memcg is initialized: mem_cgroup_iter()
-        * orders reading memcg->initialized against its callers
-        * reading the memcg members.
-        */
-       smp_store_release(&memcg->initialized, 1);
+       error = memcg_propagate_kmem(parent, memcg);
+       if (error)
+               goto fail;
+
+       if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
+               static_branch_inc(&memcg_sockets_enabled_key);
+
+       return &memcg->css;
+fail:
+       mem_cgroup_free(memcg);
+       return ERR_PTR(error);
+}
+
+static int
+mem_cgroup_css_online(struct cgroup_subsys_state *css)
+{
+       if (css->id > MEM_CGROUP_ID_MAX)
+               return -ENOSPC;
 
        return 0;
 }
@@ -4341,10 +4261,7 @@ static void mem_cgroup_css_offline(struct cgroup_subsys_state *css)
        }
        spin_unlock(&memcg->event_list_lock);
 
-       vmpressure_cleanup(&memcg->vmpressure);
-
-       memcg_deactivate_kmem(memcg);
-
+       memcg_offline_kmem(memcg);
        wb_memcg_offline(memcg);
 }
 
@@ -4359,8 +4276,17 @@ static void mem_cgroup_css_free(struct cgroup_subsys_state *css)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
-       memcg_destroy_kmem(memcg);
-       __mem_cgroup_free(memcg);
+       if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
+               static_branch_dec(&memcg_sockets_enabled_key);
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && memcg->tcpmem_active)
+               static_branch_dec(&memcg_sockets_enabled_key);
+
+       vmpressure_cleanup(&memcg->vmpressure);
+       cancel_work_sync(&memcg->high_work);
+       mem_cgroup_remove_from_trees(memcg);
+       memcg_free_kmem(memcg);
+       mem_cgroup_free(memcg);
 }
 
 /**
@@ -4476,7 +4402,7 @@ static struct page *mc_handle_swap_pte(struct vm_area_struct *vma,
         * we call find_get_page() with swapper_space directly.
         */
        page = find_get_page(swap_address_space(ent), ent.val);
-       if (do_swap_account)
+       if (do_memsw_account())
                entry->val = ent.val;
 
        return page;
@@ -4511,7 +4437,7 @@ static struct page *mc_handle_file_pte(struct vm_area_struct *vma,
                page = find_get_entry(mapping, pgoff);
                if (radix_tree_exceptional_entry(page)) {
                        swp_entry_t swp = radix_to_swp_entry(page);
-                       if (do_swap_account)
+                       if (do_memsw_account())
                                *entry = swp;
                        page = find_get_page(swap_address_space(swp), swp.val);
                }
@@ -4530,38 +4456,30 @@ static struct page *mc_handle_file_pte(struct vm_area_struct *vma,
  * @from: mem_cgroup which the page is moved from.
  * @to:        mem_cgroup which the page is moved to. @from != @to.
  *
- * The caller must confirm following.
- * - page is not on LRU (isolate_page() is useful.)
- * - compound_lock is held when nr_pages > 1
+ * The caller must make sure the page is not on LRU (isolate_page() is useful.)
  *
  * This function doesn't do "charge" to new cgroup and doesn't do "uncharge"
  * from old cgroup.
  */
 static int mem_cgroup_move_account(struct page *page,
-                                  unsigned int nr_pages,
+                                  bool compound,
                                   struct mem_cgroup *from,
                                   struct mem_cgroup *to)
 {
        unsigned long flags;
+       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
        int ret;
        bool anon;
 
        VM_BUG_ON(from == to);
        VM_BUG_ON_PAGE(PageLRU(page), page);
-       /*
-        * The page is isolated from LRU. So, collapse function
-        * will not handle this page. But page splitting can happen.
-        * Do this check under compound_page_lock(). The caller should
-        * hold it.
-        */
-       ret = -EBUSY;
-       if (nr_pages > 1 && !PageTransHuge(page))
-               goto out;
+       VM_BUG_ON(compound && !PageTransHuge(page));
 
        /*
         * Prevent mem_cgroup_replace_page() from looking at
         * page->mem_cgroup of its source page while we change it.
         */
+       ret = -EBUSY;
        if (!trylock_page(page))
                goto out;
 
@@ -4616,9 +4534,9 @@ static int mem_cgroup_move_account(struct page *page,
        ret = 0;
 
        local_irq_disable();
-       mem_cgroup_charge_statistics(to, page, nr_pages);
+       mem_cgroup_charge_statistics(to, page, compound, nr_pages);
        memcg_check_events(to, page);
-       mem_cgroup_charge_statistics(from, page, -nr_pages);
+       mem_cgroup_charge_statistics(from, page, compound, -nr_pages);
        memcg_check_events(from, page);
        local_irq_enable();
 out_unlock:
@@ -4708,7 +4626,7 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
        pte_t *pte;
        spinlock_t *ptl;
 
-       if (pmd_trans_huge_lock(pmd, vma, &ptl) == 1) {
+       if (pmd_trans_huge_lock(pmd, vma, &ptl)) {
                if (get_mctgt_type_thp(vma, addr, *pmd, NULL) == MC_TARGET_PAGE)
                        mc.precharge += HPAGE_PMD_NR;
                spin_unlock(ptl);
@@ -4813,7 +4731,7 @@ static void mem_cgroup_clear_mc(void)
 static int mem_cgroup_can_attach(struct cgroup_taskset *tset)
 {
        struct cgroup_subsys_state *css;
-       struct mem_cgroup *memcg;
+       struct mem_cgroup *memcg = NULL; /* unneeded init to make gcc happy */
        struct mem_cgroup *from;
        struct task_struct *leader, *p;
        struct mm_struct *mm;
@@ -4896,17 +4814,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
        union mc_target target;
        struct page *page;
 
-       /*
-        * We don't take compound_lock() here but no race with splitting thp
-        * happens because:
-        *  - if pmd_trans_huge_lock() returns 1, the relevant thp is not
-        *    under splitting, which means there's no concurrent thp split,
-        *  - if another thread runs into split_huge_page() just after we
-        *    entered this if-block, the thread must wait for page table lock
-        *    to be unlocked in __split_huge_page_splitting(), where the main
-        *    part of thp split is not executed yet.
-        */
-       if (pmd_trans_huge_lock(pmd, vma, &ptl) == 1) {
+       if (pmd_trans_huge_lock(pmd, vma, &ptl)) {
                if (mc.precharge < HPAGE_PMD_NR) {
                        spin_unlock(ptl);
                        return 0;
@@ -4915,7 +4823,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
                if (target_type == MC_TARGET_PAGE) {
                        page = target.page;
                        if (!isolate_lru_page(page)) {
-                               if (!mem_cgroup_move_account(page, HPAGE_PMD_NR,
+                               if (!mem_cgroup_move_account(page, true,
                                                             mc.from, mc.to)) {
                                        mc.precharge -= HPAGE_PMD_NR;
                                        mc.moved_charge += HPAGE_PMD_NR;
@@ -4942,9 +4850,18 @@ retry:
                switch (get_mctgt_type(vma, addr, ptent, &target)) {
                case MC_TARGET_PAGE:
                        page = target.page;
+                       /*
+                        * We can have a part of the split pmd here. Moving it
+                        * can be done but it would be too convoluted so simply
+                        * ignore such a partial THP and keep it in original
+                        * memcg. There should be somebody mapping the head.
+                        */
+                       if (PageTransCompound(page))
+                               goto put;
                        if (isolate_lru_page(page))
                                goto put;
-                       if (!mem_cgroup_move_account(page, 1, mc.from, mc.to)) {
+                       if (!mem_cgroup_move_account(page, false,
+                                               mc.from, mc.to)) {
                                mc.precharge--;
                                /* we uncharge from mc.from later. */
                                mc.moved_charge++;
@@ -5283,10 +5200,11 @@ bool mem_cgroup_low(struct mem_cgroup *root, struct mem_cgroup *memcg)
  * with mem_cgroup_cancel_charge() in case page instantiation fails.
  */
 int mem_cgroup_try_charge(struct page *page, struct mm_struct *mm,
-                         gfp_t gfp_mask, struct mem_cgroup **memcgp)
+                         gfp_t gfp_mask, struct mem_cgroup **memcgp,
+                         bool compound)
 {
        struct mem_cgroup *memcg = NULL;
-       unsigned int nr_pages = 1;
+       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
        int ret = 0;
 
        if (mem_cgroup_disabled())
@@ -5316,11 +5234,6 @@ int mem_cgroup_try_charge(struct page *page, struct mm_struct *mm,
                }
        }
 
-       if (PageTransHuge(page)) {
-               nr_pages <<= compound_order(page);
-               VM_BUG_ON_PAGE(!PageTransHuge(page), page);
-       }
-
        if (!memcg)
                memcg = get_mem_cgroup_from_mm(mm);
 
@@ -5349,9 +5262,9 @@ out:
  * Use mem_cgroup_cancel_charge() to cancel the transaction instead.
  */
 void mem_cgroup_commit_charge(struct page *page, struct mem_cgroup *memcg,
-                             bool lrucare)
+                             bool lrucare, bool compound)
 {
-       unsigned int nr_pages = 1;
+       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
 
        VM_BUG_ON_PAGE(!page->mapping, page);
        VM_BUG_ON_PAGE(PageLRU(page) && !lrucare, page);
@@ -5368,17 +5281,12 @@ void mem_cgroup_commit_charge(struct page *page, struct mem_cgroup *memcg,
 
        commit_charge(page, memcg, lrucare);
 
-       if (PageTransHuge(page)) {
-               nr_pages <<= compound_order(page);
-               VM_BUG_ON_PAGE(!PageTransHuge(page), page);
-       }
-
        local_irq_disable();
-       mem_cgroup_charge_statistics(memcg, page, nr_pages);
+       mem_cgroup_charge_statistics(memcg, page, compound, nr_pages);
        memcg_check_events(memcg, page);
        local_irq_enable();
 
-       if (do_swap_account && PageSwapCache(page)) {
+       if (do_memsw_account() && PageSwapCache(page)) {
                swp_entry_t entry = { .val = page_private(page) };
                /*
                 * The swap entry might not get freed for a long time,
@@ -5396,9 +5304,10 @@ void mem_cgroup_commit_charge(struct page *page, struct mem_cgroup *memcg,
  *
  * Cancel a charge transaction started by mem_cgroup_try_charge().
  */
-void mem_cgroup_cancel_charge(struct page *page, struct mem_cgroup *memcg)
+void mem_cgroup_cancel_charge(struct page *page, struct mem_cgroup *memcg,
+               bool compound)
 {
-       unsigned int nr_pages = 1;
+       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
 
        if (mem_cgroup_disabled())
                return;
@@ -5410,11 +5319,6 @@ void mem_cgroup_cancel_charge(struct page *page, struct mem_cgroup *memcg)
        if (!memcg)
                return;
 
-       if (PageTransHuge(page)) {
-               nr_pages <<= compound_order(page);
-               VM_BUG_ON_PAGE(!PageTransHuge(page), page);
-       }
-
        cancel_charge(memcg, nr_pages);
 }
 
@@ -5427,7 +5331,7 @@ static void uncharge_batch(struct mem_cgroup *memcg, unsigned long pgpgout,
 
        if (!mem_cgroup_is_root(memcg)) {
                page_counter_uncharge(&memcg->memory, nr_pages);
-               if (do_swap_account)
+               if (do_memsw_account())
                        page_counter_uncharge(&memcg->memsw, nr_pages);
                memcg_oom_recover(memcg);
        }
@@ -5553,7 +5457,8 @@ void mem_cgroup_uncharge_list(struct list_head *page_list)
 void mem_cgroup_replace_page(struct page *oldpage, struct page *newpage)
 {
        struct mem_cgroup *memcg;
-       int isolated;
+       unsigned int nr_pages;
+       bool compound;
 
        VM_BUG_ON_PAGE(!PageLocked(oldpage), oldpage);
        VM_BUG_ON_PAGE(!PageLocked(newpage), newpage);
@@ -5573,13 +5478,130 @@ void mem_cgroup_replace_page(struct page *oldpage, struct page *newpage)
        if (!memcg)
                return;
 
-       lock_page_lru(oldpage, &isolated);
-       oldpage->mem_cgroup = NULL;
-       unlock_page_lru(oldpage, isolated);
+       /* Force-charge the new page. The old one will be freed soon */
+       compound = PageTransHuge(newpage);
+       nr_pages = compound ? hpage_nr_pages(newpage) : 1;
+
+       page_counter_charge(&memcg->memory, nr_pages);
+       if (do_memsw_account())
+               page_counter_charge(&memcg->memsw, nr_pages);
+       css_get_many(&memcg->css, nr_pages);
 
        commit_charge(newpage, memcg, true);
+
+       local_irq_disable();
+       mem_cgroup_charge_statistics(memcg, newpage, compound, nr_pages);
+       memcg_check_events(memcg, newpage);
+       local_irq_enable();
+}
+
+DEFINE_STATIC_KEY_FALSE(memcg_sockets_enabled_key);
+EXPORT_SYMBOL(memcg_sockets_enabled_key);
+
+void sock_update_memcg(struct sock *sk)
+{
+       struct mem_cgroup *memcg;
+
+       /* Socket cloning can throw us here with sk_memcg already
+        * filled. It won't, however, necessarily happen from
+        * process context. So the test for root memcg given
+        * the current task's memcg won't help us in this case.
+        *
+        * Respecting the original socket's memcg is a better
+        * decision in this case.
+        */
+       if (sk->sk_memcg) {
+               BUG_ON(mem_cgroup_is_root(sk->sk_memcg));
+               css_get(&sk->sk_memcg->css);
+               return;
+       }
+
+       rcu_read_lock();
+       memcg = mem_cgroup_from_task(current);
+       if (memcg == root_mem_cgroup)
+               goto out;
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && !memcg->tcpmem_active)
+               goto out;
+       if (css_tryget_online(&memcg->css))
+               sk->sk_memcg = memcg;
+out:
+       rcu_read_unlock();
+}
+EXPORT_SYMBOL(sock_update_memcg);
+
+void sock_release_memcg(struct sock *sk)
+{
+       WARN_ON(!sk->sk_memcg);
+       css_put(&sk->sk_memcg->css);
+}
+
+/**
+ * mem_cgroup_charge_skmem - charge socket memory
+ * @memcg: memcg to charge
+ * @nr_pages: number of pages to charge
+ *
+ * Charges @nr_pages to @memcg. Returns %true if the charge fit within
+ * @memcg's configured limit, %false if the charge had to be forced.
+ */
+bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
+{
+       gfp_t gfp_mask = GFP_KERNEL;
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
+               struct page_counter *fail;
+
+               if (page_counter_try_charge(&memcg->tcpmem, nr_pages, &fail)) {
+                       memcg->tcpmem_pressure = 0;
+                       return true;
+               }
+               page_counter_charge(&memcg->tcpmem, nr_pages);
+               memcg->tcpmem_pressure = 1;
+               return false;
+       }
+
+       /* Don't block in the packet receive path */
+       if (in_softirq())
+               gfp_mask = GFP_NOWAIT;
+
+       if (try_charge(memcg, gfp_mask, nr_pages) == 0)
+               return true;
+
+       try_charge(memcg, gfp_mask|__GFP_NOFAIL, nr_pages);
+       return false;
 }
 
+/**
+ * mem_cgroup_uncharge_skmem - uncharge socket memory
+ * @memcg: memcg to uncharge
+ * @nr_pages: number of pages to uncharge
+ */
+void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
+{
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
+               page_counter_uncharge(&memcg->tcpmem, nr_pages);
+               return;
+       }
+
+       page_counter_uncharge(&memcg->memory, nr_pages);
+       css_put_many(&memcg->css, nr_pages);
+}
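To show how the charge and uncharge sides are meant to pair up, here is a hypothetical buffer-accounting helper. Only mem_cgroup_charge_skmem(), mem_cgroup_uncharge_skmem() and sk->sk_memcg are taken from this patch; the example_* wrappers and the throttle-on-false policy are illustrative assumptions.

/* Sketch (not in this patch): account nr_pages of socket buffer memory to
 * the socket's memcg.  mem_cgroup_charge_skmem() always charges, but it
 * returns false when the charge had to be forced past the limit, which a
 * protocol could treat as a memory-pressure signal and throttle on.
 */
static bool example_account_skbuf(struct sock *sk, unsigned int nr_pages)
{
        if (!sk->sk_memcg)
                return true;
        return mem_cgroup_charge_skmem(sk->sk_memcg, nr_pages);
}

/* Must mirror example_account_skbuf() exactly when the buffer is freed. */
static void example_unaccount_skbuf(struct sock *sk, unsigned int nr_pages)
{
        if (sk->sk_memcg)
                mem_cgroup_uncharge_skmem(sk->sk_memcg, nr_pages);
}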
+
+static int __init cgroup_memory(char *s)
+{
+       char *token;
+
+       while ((token = strsep(&s, ",")) != NULL) {
+               if (!*token)
+                       continue;
+               if (!strcmp(token, "nosocket"))
+                       cgroup_memory_nosocket = true;
+               if (!strcmp(token, "nokmem"))
+                       cgroup_memory_nokmem = true;
+       }
+       return 0;
+}
+__setup("cgroup.memory=", cgroup_memory);
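The parser above accepts a comma-separated token list on the kernel command line. For example, booting with

        cgroup.memory=nosocket,nokmem

sets both flags and thereby disables socket memory accounting and kernel memory accounting, respectively; empty tokens (e.g. from a trailing comma) are simply skipped.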
+
 /*
  * subsys_initcall() for memory controller.
  *
@@ -5635,7 +5657,7 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
        VM_BUG_ON_PAGE(PageLRU(page), page);
        VM_BUG_ON_PAGE(page_count(page), page);
 
-       if (!do_swap_account)
+       if (!do_memsw_account())
                return;
 
        memcg = page->mem_cgroup;
@@ -5660,15 +5682,51 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
         * only synchronisation we have for updating the per-CPU variables.
         */
        VM_BUG_ON(!irqs_disabled());
-       mem_cgroup_charge_statistics(memcg, page, -1);
+       mem_cgroup_charge_statistics(memcg, page, false, -1);
        memcg_check_events(memcg, page);
 }
 
+/**
+ * mem_cgroup_try_charge_swap - try charging a swap entry
+ * @page: page being added to swap
+ * @entry: swap entry to charge
+ *
+ * Try to charge @entry to the memcg that @page belongs to.
+ *
+ * Returns 0 on success, -ENOMEM on failure.
+ */
+int mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry)
+{
+       struct mem_cgroup *memcg;
+       struct page_counter *counter;
+       unsigned short oldid;
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) || !do_swap_account)
+               return 0;
+
+       memcg = page->mem_cgroup;
+
+       /* Readahead page, never charged */
+       if (!memcg)
+               return 0;
+
+       if (!mem_cgroup_is_root(memcg) &&
+           !page_counter_try_charge(&memcg->swap, 1, &counter))
+               return -ENOMEM;
+
+       oldid = swap_cgroup_record(entry, mem_cgroup_id(memcg));
+       VM_BUG_ON_PAGE(oldid, page);
+       mem_cgroup_swap_statistics(memcg, true);
+
+       css_get(&memcg->css);
+       return 0;
+}
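As a sketch of the intended call site, the swap-out path would try to charge the swap entry right after allocating it and back out if the memcg has no swap budget left. The example_swap_out() wrapper is hypothetical; get_swap_page() and swapcache_free() are assumed to be the kernel's existing swap-allocation helpers, and only mem_cgroup_try_charge_swap() comes from this patch.

/* Sketch (not in this patch): give the swap slot back if the page's memcg
 * cannot charge it, so the page stays resident instead of going to swap.
 */
static int example_swap_out(struct page *page)
{
        swp_entry_t entry = get_swap_page();    /* assumed existing helper */

        if (!entry.val)
                return -ENOMEM;
        if (mem_cgroup_try_charge_swap(page, entry)) {
                swapcache_free(entry);          /* assumed existing helper */
                return -ENOMEM;
        }
        /* ... continue adding @page to the swap cache under @entry ... */
        return 0;
}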
+
 /**
  * mem_cgroup_uncharge_swap - uncharge a swap entry
  * @entry: swap entry to uncharge
  *
- * Drop the memsw charge associated with @entry.
+ * Drop the swap charge associated with @entry.
  */
 void mem_cgroup_uncharge_swap(swp_entry_t entry)
 {
@@ -5682,14 +5740,53 @@ void mem_cgroup_uncharge_swap(swp_entry_t entry)
        rcu_read_lock();
        memcg = mem_cgroup_from_id(id);
        if (memcg) {
-               if (!mem_cgroup_is_root(memcg))
-                       page_counter_uncharge(&memcg->memsw, 1);
+               if (!mem_cgroup_is_root(memcg)) {
+                       if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
+                               page_counter_uncharge(&memcg->swap, 1);
+                       else
+                               page_counter_uncharge(&memcg->memsw, 1);
+               }
                mem_cgroup_swap_statistics(memcg, false);
                css_put(&memcg->css);
        }
        rcu_read_unlock();
 }
 
+long mem_cgroup_get_nr_swap_pages(struct mem_cgroup *memcg)
+{
+       long nr_swap_pages = get_nr_swap_pages();
+
+       if (!do_swap_account || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               return nr_swap_pages;
+       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg))
+               nr_swap_pages = min_t(long, nr_swap_pages,
+                                     READ_ONCE(memcg->swap.limit) -
+                                     page_counter_read(&memcg->swap));
+       return nr_swap_pages;
+}
+
+bool mem_cgroup_swap_full(struct page *page)
+{
+       struct mem_cgroup *memcg;
+
+       VM_BUG_ON_PAGE(!PageLocked(page), page);
+
+       if (vm_swap_full())
+               return true;
+       if (!do_swap_account || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               return false;
+
+       memcg = page->mem_cgroup;
+       if (!memcg)
+               return false;
+
+       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg))
+               if (page_counter_read(&memcg->swap) * 2 >= memcg->swap.limit)
+                       return true;
+
+       return false;
+}
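Both helpers above are reclaim-side predicates. A rough sketch of how they could be consulted follows; the example_* functions are hypothetical, try_to_free_swap() is assumed to be the existing helper for dropping a page's swap slot, and the caller is assumed to hold the page lock as mem_cgroup_swap_full() requires.

/* Sketch (not in this patch): skip anonymous reclaim when the memcg
 * hierarchy has no swap headroom left, and free swap slots eagerly once
 * a memcg has used more than half of its swap limit.
 */
static bool example_worth_scanning_anon(struct mem_cgroup *memcg)
{
        return mem_cgroup_get_nr_swap_pages(memcg) > 0;
}

static void example_trim_swap_cache(struct page *page) /* page is locked */
{
        if (PageSwapCache(page) && mem_cgroup_swap_full(page))
                try_to_free_swap(page);         /* assumed existing helper */
}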
+
 /* To remember the boot option */
 #ifdef CONFIG_MEMCG_SWAP_ENABLED
 static int really_do_swap_account __initdata = 1;
@@ -5707,6 +5804,63 @@ static int __init enable_swap_account(char *s)
 }
 __setup("swapaccount=", enable_swap_account);
 
+static u64 swap_current_read(struct cgroup_subsys_state *css,
+                            struct cftype *cft)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+
+       return (u64)page_counter_read(&memcg->swap) * PAGE_SIZE;
+}
+
+static int swap_max_show(struct seq_file *m, void *v)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(seq_css(m));
+       unsigned long max = READ_ONCE(memcg->swap.limit);
+
+       if (max == PAGE_COUNTER_MAX)
+               seq_puts(m, "max\n");
+       else
+               seq_printf(m, "%llu\n", (u64)max * PAGE_SIZE);
+
+       return 0;
+}
+
+static ssize_t swap_max_write(struct kernfs_open_file *of,
+                             char *buf, size_t nbytes, loff_t off)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+       unsigned long max;
+       int err;
+
+       buf = strstrip(buf);
+       err = page_counter_memparse(buf, "max", &max);
+       if (err)
+               return err;
+
+       mutex_lock(&memcg_limit_mutex);
+       err = page_counter_limit(&memcg->swap, max);
+       mutex_unlock(&memcg_limit_mutex);
+       if (err)
+               return err;
+
+       return nbytes;
+}
+
+static struct cftype swap_files[] = {
+       {
+               .name = "swap.current",
+               .flags = CFTYPE_NOT_ON_ROOT,
+               .read_u64 = swap_current_read,
+       },
+       {
+               .name = "swap.max",
+               .flags = CFTYPE_NOT_ON_ROOT,
+               .seq_show = swap_max_show,
+               .write = swap_max_write,
+       },
+       { }     /* terminate */
+};
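From userspace, the two cftypes above surface on the cgroup2 hierarchy as memory.swap.current and memory.swap.max in every non-root group. A small illustrative program follows; the /sys/fs/cgroup mount point and the "job" group name are assumptions, only the file names and their byte/"max" semantics come from the handlers above.

#include <stdio.h>

int main(void)
{
        char buf[64];
        FILE *f;

        /* Cap the group's swap usage at 512 MiB; writing "max" lifts the cap. */
        f = fopen("/sys/fs/cgroup/job/memory.swap.max", "w");
        if (!f)
                return 1;
        fprintf(f, "%llu\n", 512ULL << 20);
        fclose(f);

        /* memory.swap.current reports the group's current swap usage in bytes. */
        f = fopen("/sys/fs/cgroup/job/memory.swap.current", "r");
        if (!f)
                return 1;
        if (fgets(buf, sizeof(buf), f))
                printf("swap usage: %s", buf);
        fclose(f);
        return 0;
}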
+
 static struct cftype memsw_cgroup_files[] = {
        {
                .name = "memsw.usage_in_bytes",
@@ -5738,6 +5892,8 @@ static int __init mem_cgroup_swap_init(void)
 {
        if (!mem_cgroup_disabled() && really_do_swap_account) {
                do_swap_account = 1;
+               WARN_ON(cgroup_add_dfl_cftypes(&memory_cgrp_subsys,
+                                              swap_files));
                WARN_ON(cgroup_add_legacy_cftypes(&memory_cgrp_subsys,
                                                  memsw_cgroup_files));
        }