autoupdate.py: Allow max_updates through update URL

Currently in order to put a ceiling on the number of updates that can be
performed in a devserver, we need to spawn a new devserver with flag
--max_updates. The problem is that each run of an autotest then has to
spawn a new devserver alongside the lab devservers, which is not
ideal.

We solve this problem by dynamically configuring the devserver
using a unique identifier. This is done by calling into the 'session_id'
API of the devserver. Clients can then append this session_id to their
requests so that the devserver responds accordingly.

For maximum updates, we first configure the devserver to set a
'max_updates' data for a unique session ID. Then the client can send
requests using the session ID as a query string to be capped on the
number of updates it gets.

BUG=chromium:1004489
TEST=autoupdate_unittest.py
TEST=devserver_integration_test.py

Change-Id: Ieef921b177ba0ec789d6471a34a4f8e44f5482af
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/dev-util/+/1996148
Tested-by: Amin Hassani <ahassani@chromium.org>
Commit-Queue: Amin Hassani <ahassani@chromium.org>
Reviewed-by: Allen Li <ayatane@chromium.org>
diff --git a/autoupdate.py b/autoupdate.py
index a309a79..5e76467 100644
--- a/autoupdate.py
+++ b/autoupdate.py
@@ -7,6 +7,9 @@
 
 from __future__ import print_function
 
+import collections
+import contextlib
+import datetime
 import json
 import os
 import threading
@@ -113,6 +116,92 @@
     return self.table.get(host_id)
 
 
+class SessionTable(object):
+  """A class to keep a map of session IDs and data.
+
+  This can be used to set some configuration related to a session and
+  retrieve/manipulate the configuration whenever needed. This is basically a map
+  of string to a dict object.
+  """
+
+  SESSION_EXPIRATION_TIMEDIFF = datetime.timedelta(hours=1)
+  OCCASIONAL_PURGE_TIMEDIFF = datetime.timedelta(hours=1)
+
+  Session = collections.namedtuple('Session', ['timestamp', 'data'])
+
+  def __init__(self):
+    """Initializes the SessionTable class."""
+    self._table = {}
+    # Since multiple requests might come for this session table by multiple
+    # threads, keep it under a lock.
+    self._lock = threading.Lock()
+    self._last_purge_time = datetime.datetime.now()
+
+  def _ShouldPurge(self):
+    """Returns whether it's time to do an occasional purge."""
+    return (datetime.datetime.now() - self._last_purge_time >
+            self.OCCASIONAL_PURGE_TIMEDIFF)
+
+  def _IsSessionExpired(self, session):
+    """Returns whether a session needs to be purged.
+
+    Args:
+      session: A Session tuple (timestamp, data) stored in the table.
+    """
+    return (datetime.datetime.now() - session.timestamp >
+            self.SESSION_EXPIRATION_TIMEDIFF)
+
+  def _Purge(self):
+    """Cleans up entries that have been here long enough.
+
+    This is so the memory usage of devserver doesn't get bloated.
+    """
+    # Try to purge once every hour or so.
+    if not self._ShouldPurge():
+      return
+
+    # Purge the ones not in use.
+    self._table = {k: v for k, v in self._table.items()
+                   if not self._IsSessionExpired(v)}
+
+  def SetSessionData(self, session, data):
+    """Sets data for the given session ID.
+
+    Args:
+      session: A unique identifier string.
+      data: The data to set for this session ID.
+    """
+    if not session or data is None:
+      return
+
+    with self._lock:
+      self._Purge()
+
+      if self._table.get(session) is not None:
+        _Log('Replacing an existing session %s', session)
+      self._table[session] = SessionTable.Session(datetime.datetime.now(), data)
+
+  @contextlib.contextmanager
+  def SessionData(self, session):
+    """Returns the session data for manipulation.
+
+    Args:
+      session: A unique identifier string.
+    """
+    # Cherrypy has multiple threads and this data structure is global, so lock
+    # it to restrict simultaneous access by multiple threads.
+    with self._lock:
+      session_value = self._table.get(session)
+      # If not in the table, just assume it wasn't supposed to be.
+      if session_value is None:
+        yield {}
+      else:
+        # To update the timestamp.
+        self._table[session] = SessionTable.Session(datetime.datetime.now(),
+                                                    session_value.data)
+        yield session_value.data
+
+
 class Autoupdate(build_util.BuildObject):
   """Class that contains functionality that handles Chrome OS update pings."""
 
@@ -146,6 +235,8 @@
     # host, as well as a dictionary of current attributes derived from events.
     self.host_infos = HostInfoTable()
 
+    self._session_table = SessionTable()
+
     self._update_count_lock = threading.Lock()
 
   def GetUpdateForLabel(self, label):
@@ -319,20 +410,32 @@
     request = nebraska.Request(data)
     self._LogRequest(request)
 
+    session = kwargs.get('session')
+    _Log('Requested session is: %s', session)
+
     if request.request_type == nebraska.Request.RequestType.EVENT:
       if (request.app_requests[0].event_type ==
           nebraska.Request.EVENT_TYPE_UPDATE_DOWNLOAD_STARTED and
           request.app_requests[0].event_result ==
           nebraska.Request.EVENT_RESULT_SUCCESS):
+        err_msg = ('Received too many download_started notifications. This '
+                   'probably means a bug in the test environment, such as too '
+                   'many clients running concurrently. Alternatively, it could '
+                   'be a bug in the update client.')
+
         with self._update_count_lock:
           if self.max_updates == 0:
-            _Log('Received too many download_started notifications. This '
-                 'probably means a bug in the test environment, such as too '
-                 'many clients running concurrently. Alternatively, it could '
-                 'be a bug in the update client.')
+            _Log(err_msg)
           elif self.max_updates > 0:
             self.max_updates -= 1
 
+        with self._session_table.SessionData(session) as session_data:
+          value = session_data.get('max_updates')
+          if value is not None:
+            session_data['max_updates'] = max(value - 1, 0)
+            if value == 0:
+              _Log(err_msg)
+
       _Log('A non-update event notification received. Returning an ack.')
       nebraska_obj = nebraska.Nebraska()
       return nebraska_obj.GetResponseToRequest(request)
@@ -341,7 +444,11 @@
     # responses. Note that the counter is only decremented when the client
     # reports an actual download, to avoid race conditions between concurrent
     # update requests from the same client due to a timeout.
-    if self.max_updates == 0:
+    max_updates = None
+    with self._session_table.SessionData(session) as session_data:
+      max_updates = session_data.get('max_updates')
+
+    if self.max_updates == 0 or max_updates == 0:
       _Log('Request received but max number of updates already served.')
       nebraska_obj = nebraska.Nebraska()
       response_props = nebraska.ResponseProperties(no_update=True)
@@ -391,3 +498,12 @@
 
     # If no events were logged for this IP, return an empty log.
     return json.dumps([])
+
+  def SetSessionData(self, session, data):
+    """Sets the data for the given session ID.
+
+    Args:
+      session: A unique identifier string.
+      data: A dictionary containing some data.
+    """
+    self._session_table.SetSessionData(session, data)