autoupdate.py: Allow max_updates through update URL
Currently, in order to put a ceiling on the number of updates that can be
performed in a devserver, we need to spawn a new devserver with the flag
--max_updates. The problem is that each autotest run then has to spawn its
own devserver alongside the lab devservers, which is not ideal.
We solve this problem by dynamically configuring the devserver
using a unique identifier. This is done by calling the devserver's
'session_id' API. Clients can then append this session ID to their
requests so that they are served according to that configuration.
For maximum updates, we first configure the devserver to store
'max_updates' data for a unique session ID. Clients can then send
requests with the session ID as a query string, and the number of
updates they get is capped accordingly.
BUG=chromium:1004489
TEST=autoupdate_unittest.py
TEST=devserver_integration_test.py
Change-Id: Ieef921b177ba0ec789d6471a34a4f8e44f5482af
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/dev-util/+/1996148
Tested-by: Amin Hassani <ahassani@chromium.org>
Commit-Queue: Amin Hassani <ahassani@chromium.org>
Reviewed-by: Allen Li <ayatane@chromium.org>
diff --git a/autoupdate.py b/autoupdate.py
index a309a79..5e76467 100644
--- a/autoupdate.py
+++ b/autoupdate.py
@@ -7,6 +7,9 @@
from __future__ import print_function
+import collections
+import contextlib
+import datetime
import json
import os
import threading
@@ -113,6 +116,92 @@
return self.table.get(host_id)
+class SessionTable(object):
+  """A thread-safe map of session IDs to per-session data.
+
+  This can be used to set some configuration related to a session and
+  retrieve/manipulate it whenever needed. Entries untouched for longer than
+  SESSION_EXPIRATION_TIMEDIFF are dropped by an occasional purge.
+  """
+
+  SESSION_EXPIRATION_TIMEDIFF = datetime.timedelta(hours=1)
+  OCCASIONAL_PURGE_TIMEDIFF = datetime.timedelta(hours=1)
+
+  Session = collections.namedtuple('Session', ['timestamp', 'data'])
+
+  def __init__(self):
+    """Initializes the SessionTable class."""
+    self._table = {}
+    # Cherrypy serves requests from multiple threads, so _table and
+    # _last_purge_time are only touched while holding this lock.
+    self._lock = threading.Lock()
+    self._last_purge_time = datetime.datetime.now()
+
+  def _ShouldPurge(self):
+    """Returns whether it's time to do an occasional purge."""
+    return (datetime.datetime.now() - self._last_purge_time >
+            self.OCCASIONAL_PURGE_TIMEDIFF)
+
+  def _IsSessionExpired(self, session):
+    """Returns whether a session entry is old enough to be purged.
+
+    Args:
+      session: A SessionTable.Session tuple (not the session ID string).
+    """
+    return (datetime.datetime.now() - session.timestamp >
+            self.SESSION_EXPIRATION_TIMEDIFF)
+
+  def _Purge(self):
+    """Drops expired entries, at most once per OCCASIONAL_PURGE_TIMEDIFF.
+
+    This keeps devserver memory from bloating. Caller must hold self._lock.
+    """
+    if not self._ShouldPurge():
+      return
+    # Restart the purge clock; otherwise every call after the first interval
+    # would rescan the whole table.
+    self._last_purge_time = datetime.datetime.now()
+    self._table = {k: v for k, v in self._table.items()
+                   if not self._IsSessionExpired(v)}
+
+  def SetSessionData(self, session, data):
+    """Sets data for the given session ID, replacing any existing data.
+
+    Args:
+      session: A unique identifier string. Ignored if empty.
+      data: The data to set for this session ID. Ignored if None.
+    """
+    if not session or data is None:
+      return
+
+    with self._lock:
+      self._Purge()
+
+      if session in self._table:
+        _Log('Replacing an existing session %s', session)
+      self._table[session] = SessionTable.Session(datetime.datetime.now(), data)
+
+  @contextlib.contextmanager
+  def SessionData(self, session):
+    """Context manager yielding the session's data for manipulation.
+
+    Holds the (non-reentrant) lock for the whole `with` body and refreshes the
+    session's timestamp. For an unknown session a throwaway empty dict is
+    yielded, so any changes made to it are discarded.
+
+    Args:
+      session: A unique identifier string.
+    """
+    with self._lock:
+      session_value = self._table.get(session)
+      if session_value is None:
+        yield {}
+      else:
+        self._table[session] = SessionTable.Session(datetime.datetime.now(),
+                                                    session_value.data)
+        yield session_value.data
+
+
class Autoupdate(build_util.BuildObject):
"""Class that contains functionality that handles Chrome OS update pings."""
@@ -146,6 +235,8 @@
# host, as well as a dictionary of current attributes derived from events.
self.host_infos = HostInfoTable()
+ self._session_table = SessionTable()
+
self._update_count_lock = threading.Lock()
def GetUpdateForLabel(self, label):
@@ -319,20 +410,32 @@
request = nebraska.Request(data)
self._LogRequest(request)
+ session = kwargs.get('session')
+ _Log('Requested session is: %s', session)
+
if request.request_type == nebraska.Request.RequestType.EVENT:
if (request.app_requests[0].event_type ==
nebraska.Request.EVENT_TYPE_UPDATE_DOWNLOAD_STARTED and
request.app_requests[0].event_result ==
nebraska.Request.EVENT_RESULT_SUCCESS):
+ err_msg = ('Received too many download_started notifications. This '
+ 'probably means a bug in the test environment, such as too '
+ 'many clients running concurrently. Alternatively, it could '
+ 'be a bug in the update client.')
+
with self._update_count_lock:
if self.max_updates == 0:
- _Log('Received too many download_started notifications. This '
- 'probably means a bug in the test environment, such as too '
- 'many clients running concurrently. Alternatively, it could '
- 'be a bug in the update client.')
+ _Log(err_msg)
elif self.max_updates > 0:
self.max_updates -= 1
+ with self._session_table.SessionData(session) as session_data:
+ value = session_data.get('max_updates')
+ if value is not None:
+ session_data['max_updates'] = max(value - 1, 0)
+ if value == 0:
+ _Log(err_msg)
+
_Log('A non-update event notification received. Returning an ack.')
nebraska_obj = nebraska.Nebraska()
return nebraska_obj.GetResponseToRequest(request)
@@ -341,7 +444,11 @@
# responses. Note that the counter is only decremented when the client
# reports an actual download, to avoid race conditions between concurrent
# update requests from the same client due to a timeout.
- if self.max_updates == 0:
+ max_updates = None
+ with self._session_table.SessionData(session) as session_data:
+ max_updates = session_data.get('max_updates')
+
+ if self.max_updates == 0 or max_updates == 0:
_Log('Request received but max number of updates already served.')
nebraska_obj = nebraska.Nebraska()
response_props = nebraska.ResponseProperties(no_update=True)
@@ -391,3 +498,12 @@
# If no events were logged for this IP, return an empty log.
return json.dumps([])
+
+  def SetSessionData(self, session, data):
+    """Stores data for a session ID, replacing any previous data for it.
+
+    Args:
+      session: A unique identifier string; ignored if empty.
+      data: A dictionary of session settings, e.g. {'max_updates': 1}.
+    """
+    self._session_table.SetSessionData(session, data)