Add Abstract Base Class to WebElement and WebDriver #7127 (#8348)

* Add Abstract Base Class to WebElement and WebDriver

Co-authored-by: David Burns <david.burns@theautomatedtester.co.uk>
Cr-Mirrored-From: https://chromium.googlesource.com/external/github.com/SeleniumHQ/selenium
Cr-Mirrored-Commit: 0dc42efd8b9bce8f81a1e5a852f24b9caf98e885
diff --git a/selenium/webdriver/common/actions/pointer_actions.py b/selenium/webdriver/common/actions/pointer_actions.py
index 7d2fdc7..cc6033d 100644
--- a/selenium/webdriver/common/actions/pointer_actions.py
+++ b/selenium/webdriver/common/actions/pointer_actions.py
@@ -21,7 +21,6 @@
 from .pointer_input import PointerInput
 
 from selenium.webdriver.remote.webelement import WebElement
-from selenium.webdriver.support.event_firing_webdriver import EventFiringWebElement
 
 
 class PointerActions(Interaction):
@@ -39,7 +38,7 @@
         self._button_action("create_pointer_up", button=button)
 
     def move_to(self, element, x=None, y=None):
-        if not isinstance(element, (WebElement, EventFiringWebElement)):
+        if not isinstance(element, WebElement):
             raise AttributeError("move_to requires a WebElement")
         if x is not None or y is not None:
             el_rect = element.rect
diff --git a/selenium/webdriver/common/actions/pointer_input.py b/selenium/webdriver/common/actions/pointer_input.py
index 2216a5a..5ca74ab 100644
--- a/selenium/webdriver/common/actions/pointer_input.py
+++ b/selenium/webdriver/common/actions/pointer_input.py
@@ -19,7 +19,6 @@
 
 from selenium.common.exceptions import InvalidArgumentException
 from selenium.webdriver.remote.webelement import WebElement
-from selenium.webdriver.support.event_firing_webdriver import EventFiringWebElement
 
 
 class PointerInput(InputDevice):
@@ -28,7 +27,7 @@
 
     def __init__(self, kind, name):
         super(PointerInput, self).__init__()
-        if (kind not in POINTER_KINDS):
+        if kind not in POINTER_KINDS:
             raise InvalidArgumentException("Invalid PointerInput kind '%s'" % kind)
         self.type = POINTER
         self.kind = kind
@@ -38,7 +37,7 @@
         action = dict(type="pointerMove", duration=duration)
         action["x"] = x
         action["y"] = y
-        if isinstance(origin, (WebElement, EventFiringWebElement)):
+        if isinstance(origin, WebElement):
             action["origin"] = {"element-6066-11e4-a52e-4f735466cecf": origin.id}
         elif origin is not None:
             action["origin"] = origin
diff --git a/selenium/webdriver/firefox/webdriver.py b/selenium/webdriver/firefox/webdriver.py
index 01c3c3f..cb5e878 100644
--- a/selenium/webdriver/firefox/webdriver.py
+++ b/selenium/webdriver/firefox/webdriver.py
@@ -25,7 +25,7 @@
 from contextlib import contextmanager
 
 from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
-from selenium.webdriver.remote.webdriver import WebDriver as RemoteWebDriver
+from selenium.webdriver.remote.webdriver import WebDriver as RemoteWebDriver, WebElement
 
 from .firefox_binary import FirefoxBinary
 from .firefox_profile import FirefoxProfile
@@ -341,3 +341,17 @@
                 driver.get_full_page_screenshot_as_base64()
         """
         return self.execute("FULL_PAGE_SCREENSHOT")['value']
+
+    def _wrap_value(self, value):
+        """Overload the _wrap_value so that custom WebElement types can be checked against"""
+        if isinstance(value, dict):
+            converted = {}
+            for key, val in value.items():
+                converted[key] = self._wrap_value(val)
+            return converted
+        elif isinstance(value, (self._web_element_cls, WebElement)):
+            return {'ELEMENT': value.id, 'element-6066-11e4-a52e-4f735466cecf': value.id}
+        elif isinstance(value, list):
+            return list(self._wrap_value(item) for item in value)
+        else:
+            return value
diff --git a/selenium/webdriver/remote/webdriver.py b/selenium/webdriver/remote/webdriver.py
index 78cbec7..9a70dfc 100644
--- a/selenium/webdriver/remote/webdriver.py
+++ b/selenium/webdriver/remote/webdriver.py
@@ -17,6 +17,7 @@
 
 """The WebDriver implementation."""
 
+from abc import ABCMeta
 import base64
 import copy
 from contextlib import contextmanager
@@ -36,10 +37,8 @@
                                         NoSuchCookieException,
                                         UnknownMethodException)
 from selenium.webdriver.common.by import By
-from selenium.webdriver.common.html5.application_cache import ApplicationCache
-
 from selenium.webdriver.common.timeouts import Timeouts
-
+from selenium.webdriver.common.html5.application_cache import ApplicationCache
 from selenium.webdriver.support.relative_locator import RelativeBy
 
 try:
@@ -115,7 +114,17 @@
     return handler(command_executor, keep_alive=keep_alive)
 
 
-class WebDriver(object):
+class BaseWebDriver(object):
+    """
+    Abstract Base Class for all Webdriver subtypes.
+    ABC's allow custom implementations of Webdriver to be registered so that isinstance type checks
+    will succeed.
+    """
+    __metaclass__ = ABCMeta
+    # TODO: After dropping Python 2, use ABC instead of ABCMeta and remove all Python 2 metaclass declarations.
+
+
+class WebDriver(BaseWebDriver):
     """
     Controls a browser by sending commands to a remote server.
     This server is expected to be running the WebDriver wire protocol
@@ -162,7 +171,7 @@
             else:
                 capabilities.update(desired_capabilities)
         self.command_executor = command_executor
-        if type(self.command_executor) is bytes or isinstance(self.command_executor, str):
+        if isinstance(self.command_executor, (str, bytes)):
             self.command_executor = get_remote_connection(capabilities, command_executor=command_executor, keep_alive=keep_alive)
         self._is_remote = True
         self.session_id = None
diff --git a/selenium/webdriver/remote/webelement.py b/selenium/webdriver/remote/webelement.py
index f29b89b..b51970c 100644
--- a/selenium/webdriver/remote/webelement.py
+++ b/selenium/webdriver/remote/webelement.py
@@ -21,6 +21,7 @@
 import pkgutil
 import warnings
 import zipfile
+from abc import ABCMeta
 from io import BytesIO
 
 from selenium.common.exceptions import WebDriverException
@@ -47,7 +48,15 @@
 isDisplayed_js = pkgutil.get_data(_pkg, 'isDisplayed.js').decode('utf8')
 
 
-class WebElement(object):
+class BaseWebElement(object):
+    """
+    Abstract Base Class for WebElement.
+    ABC's will allow custom types to be registered as a WebElement to pass type checks.
+    """
+    __metaclass__ = ABCMeta
+
+
+class WebElement(BaseWebElement):
     """Represents a DOM element.
 
     Generally, all interesting operations that interact with a document will be
@@ -138,18 +147,18 @@
 
         """
 
-        attributeValue = ''
+        attribute_value = ''
         if self._w3c:
-            attributeValue = self.parent.execute_script(
+            attribute_value = self.parent.execute_script(
                 "return (%s).apply(null, arguments);" % getAttribute_js,
                 self, name)
         else:
             resp = self._execute(Command.GET_ELEMENT_ATTRIBUTE, {'name': name})
-            attributeValue = resp.get('value')
-            if attributeValue is not None:
-                if name != 'value' and attributeValue.lower() in ('true', 'false'):
-                    attributeValue = attributeValue.lower()
-        return attributeValue
+            attribute_value = resp.get('value')
+            if attribute_value is not None:
+                if name != 'value' and attribute_value.lower() in ('true', 'false'):
+                    attribute_value = attribute_value.lower()
+        return attribute_value
 
     def is_selected(self):
         """Returns whether the element is selected.
diff --git a/selenium/webdriver/support/event_firing_webdriver.py b/selenium/webdriver/support/event_firing_webdriver.py
index d5fc011..9f72feb 100644
--- a/selenium/webdriver/support/event_firing_webdriver.py
+++ b/selenium/webdriver/support/event_firing_webdriver.py
@@ -26,12 +26,12 @@
 
 
 def _wrap_elements(result, ef_driver):
-    if isinstance(result, WebElement):
+    if isinstance(result, EventFiringWebElement):
+        return result
+    elif isinstance(result, WebElement):
         return EventFiringWebElement(result, ef_driver)
     elif isinstance(result, list):
         return [_wrap_elements(item, ef_driver) for item in result]
-    else:
-        return result
 
 
 class EventFiringWebDriver(object):
@@ -353,3 +353,6 @@
         except Exception as e:
             self._listener.on_exception(e, self._driver)
             raise
+
+
+WebElement.register(EventFiringWebElement)
diff --git a/test/unit/selenium/webdriver/remote/test_subtyping.py b/test/unit/selenium/webdriver/remote/test_subtyping.py
new file mode 100644
index 0000000..5f646e1
--- /dev/null
+++ b/test/unit/selenium/webdriver/remote/test_subtyping.py
@@ -0,0 +1,55 @@
+#  Licensed to the Software Freedom Conservancy (SFC) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The SFC licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+
+from selenium.webdriver.remote.webdriver import WebElement
+from selenium.webdriver.remote.webdriver import WebDriver
+
+
+def test_web_element_not_subclassed():
+    """A registered subtype of WebElement should work with isinstance checks."""
+    class MyWebElement(object):
+        def __init__(self, parent, id, _w3c=True):
+            self.parent = parent
+            self.id = id
+            self._w3c = _w3c
+
+    # Test that non registered class instance is not instance of Remote WebElement
+    my_web_element = MyWebElement('parent', '1')
+    assert not isinstance(my_web_element, WebElement)
+
+    # Register the class as a subtype of WebElement
+    WebElement.register('MyWebElement')
+    my_registered_web_element = MyWebElement('parent', '2')
+
+    assert isinstance(my_registered_web_element, WebElement)
+
+
+def test_webdriver_not_subclassed():
+    """A registered subtype of WebDriver should work with isinstance checks."""
+    class MyWebDriver(object):
+        def __init__(self, *args, **kwargs):
+            super(MyWebDriver, self).__init__(*args, **kwargs)
+
+    # Test that non registered class instance is not instance of Remote WebDriver
+    my_driver = MyWebDriver()
+    assert not isinstance(my_driver, WebDriver)
+
+    # Register the class as a subtype of WebDriver
+    WebDriver.register(MyWebDriver)
+    my_registered_driver = MyWebDriver()
+
+    assert isinstance(my_registered_driver, MyWebDriver)