Add Abstract Base Class to WebElement and WebDriver #7127 (#8348)
* Add Abstract Base Class to WebElement and WebDriver
Co-authored-by: David Burns <david.burns@theautomatedtester.co.uk>
Cr-Mirrored-From: https://chromium.googlesource.com/external/github.com/SeleniumHQ/selenium
Cr-Mirrored-Commit: 0dc42efd8b9bce8f81a1e5a852f24b9caf98e885
diff --git a/selenium/webdriver/common/actions/pointer_actions.py b/selenium/webdriver/common/actions/pointer_actions.py
index 7d2fdc7..cc6033d 100644
--- a/selenium/webdriver/common/actions/pointer_actions.py
+++ b/selenium/webdriver/common/actions/pointer_actions.py
@@ -21,7 +21,6 @@
from .pointer_input import PointerInput
from selenium.webdriver.remote.webelement import WebElement
-from selenium.webdriver.support.event_firing_webdriver import EventFiringWebElement
class PointerActions(Interaction):
@@ -39,7 +38,7 @@
self._button_action("create_pointer_up", button=button)
def move_to(self, element, x=None, y=None):
- if not isinstance(element, (WebElement, EventFiringWebElement)):
+ if not isinstance(element, WebElement):
raise AttributeError("move_to requires a WebElement")
if x is not None or y is not None:
el_rect = element.rect
diff --git a/selenium/webdriver/common/actions/pointer_input.py b/selenium/webdriver/common/actions/pointer_input.py
index 2216a5a..5ca74ab 100644
--- a/selenium/webdriver/common/actions/pointer_input.py
+++ b/selenium/webdriver/common/actions/pointer_input.py
@@ -19,7 +19,6 @@
from selenium.common.exceptions import InvalidArgumentException
from selenium.webdriver.remote.webelement import WebElement
-from selenium.webdriver.support.event_firing_webdriver import EventFiringWebElement
class PointerInput(InputDevice):
@@ -28,7 +27,7 @@
def __init__(self, kind, name):
super(PointerInput, self).__init__()
- if (kind not in POINTER_KINDS):
+ if kind not in POINTER_KINDS:
raise InvalidArgumentException("Invalid PointerInput kind '%s'" % kind)
self.type = POINTER
self.kind = kind
@@ -38,7 +37,7 @@
action = dict(type="pointerMove", duration=duration)
action["x"] = x
action["y"] = y
- if isinstance(origin, (WebElement, EventFiringWebElement)):
+ if isinstance(origin, WebElement):
action["origin"] = {"element-6066-11e4-a52e-4f735466cecf": origin.id}
elif origin is not None:
action["origin"] = origin
diff --git a/selenium/webdriver/firefox/webdriver.py b/selenium/webdriver/firefox/webdriver.py
index 01c3c3f..cb5e878 100644
--- a/selenium/webdriver/firefox/webdriver.py
+++ b/selenium/webdriver/firefox/webdriver.py
@@ -25,7 +25,7 @@
from contextlib import contextmanager
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
-from selenium.webdriver.remote.webdriver import WebDriver as RemoteWebDriver
+from selenium.webdriver.remote.webdriver import WebDriver as RemoteWebDriver, WebElement
from .firefox_binary import FirefoxBinary
from .firefox_profile import FirefoxProfile
@@ -341,3 +341,17 @@
driver.get_full_page_screenshot_as_base64()
"""
return self.execute("FULL_PAGE_SCREENSHOT")['value']
+
+ def _wrap_value(self, value):
+ """Overload the _wrap_value so that custom WebElement types can be checked against"""
+ if isinstance(value, dict):
+ converted = {}
+ for key, val in value.items():
+ converted[key] = self._wrap_value(val)
+ return converted
+ elif isinstance(value, (self._web_element_cls, WebElement)):
+ return {'ELEMENT': value.id, 'element-6066-11e4-a52e-4f735466cecf': value.id}
+ elif isinstance(value, list):
+ return list(self._wrap_value(item) for item in value)
+ else:
+ return value
diff --git a/selenium/webdriver/remote/webdriver.py b/selenium/webdriver/remote/webdriver.py
index 78cbec7..9a70dfc 100644
--- a/selenium/webdriver/remote/webdriver.py
+++ b/selenium/webdriver/remote/webdriver.py
@@ -17,6 +17,7 @@
"""The WebDriver implementation."""
+from abc import ABCMeta
import base64
import copy
from contextlib import contextmanager
@@ -36,10 +37,8 @@
NoSuchCookieException,
UnknownMethodException)
from selenium.webdriver.common.by import By
-from selenium.webdriver.common.html5.application_cache import ApplicationCache
-
from selenium.webdriver.common.timeouts import Timeouts
-
+from selenium.webdriver.common.html5.application_cache import ApplicationCache
from selenium.webdriver.support.relative_locator import RelativeBy
try:
@@ -115,7 +114,17 @@
return handler(command_executor, keep_alive=keep_alive)
-class WebDriver(object):
+class BaseWebDriver(object):
+ """
+ Abstract Base Class for all Webdriver subtypes.
+ ABC's allow custom implementations of Webdriver to be registered so that isinstance type checks
+ will succeed.
+ """
+ __metaclass__ = ABCMeta
+ # TODO: After dropping Python 2, use ABC instead of ABCMeta and remove all Python 2 metaclass declarations.
+
+
+class WebDriver(BaseWebDriver):
"""
Controls a browser by sending commands to a remote server.
This server is expected to be running the WebDriver wire protocol
@@ -162,7 +171,7 @@
else:
capabilities.update(desired_capabilities)
self.command_executor = command_executor
- if type(self.command_executor) is bytes or isinstance(self.command_executor, str):
+ if isinstance(self.command_executor, (str, bytes)):
self.command_executor = get_remote_connection(capabilities, command_executor=command_executor, keep_alive=keep_alive)
self._is_remote = True
self.session_id = None
diff --git a/selenium/webdriver/remote/webelement.py b/selenium/webdriver/remote/webelement.py
index f29b89b..b51970c 100644
--- a/selenium/webdriver/remote/webelement.py
+++ b/selenium/webdriver/remote/webelement.py
@@ -21,6 +21,7 @@
import pkgutil
import warnings
import zipfile
+from abc import ABCMeta
from io import BytesIO
from selenium.common.exceptions import WebDriverException
@@ -47,7 +48,15 @@
isDisplayed_js = pkgutil.get_data(_pkg, 'isDisplayed.js').decode('utf8')
-class WebElement(object):
+class BaseWebElement(object):
+ """
+ Abstract Base Class for WebElement.
+ ABC's will allow custom types to be registered as a WebElement to pass type checks.
+ """
+ __metaclass__ = ABCMeta
+
+
+class WebElement(BaseWebElement):
"""Represents a DOM element.
Generally, all interesting operations that interact with a document will be
@@ -138,18 +147,18 @@
"""
- attributeValue = ''
+ attribute_value = ''
if self._w3c:
- attributeValue = self.parent.execute_script(
+ attribute_value = self.parent.execute_script(
"return (%s).apply(null, arguments);" % getAttribute_js,
self, name)
else:
resp = self._execute(Command.GET_ELEMENT_ATTRIBUTE, {'name': name})
- attributeValue = resp.get('value')
- if attributeValue is not None:
- if name != 'value' and attributeValue.lower() in ('true', 'false'):
- attributeValue = attributeValue.lower()
- return attributeValue
+ attribute_value = resp.get('value')
+ if attribute_value is not None:
+ if name != 'value' and attribute_value.lower() in ('true', 'false'):
+ attribute_value = attribute_value.lower()
+ return attribute_value
def is_selected(self):
"""Returns whether the element is selected.
diff --git a/selenium/webdriver/support/event_firing_webdriver.py b/selenium/webdriver/support/event_firing_webdriver.py
index d5fc011..9f72feb 100644
--- a/selenium/webdriver/support/event_firing_webdriver.py
+++ b/selenium/webdriver/support/event_firing_webdriver.py
@@ -26,12 +26,12 @@
def _wrap_elements(result, ef_driver):
- if isinstance(result, WebElement):
+ if isinstance(result, EventFiringWebElement):
+ return result
+ elif isinstance(result, WebElement):
return EventFiringWebElement(result, ef_driver)
elif isinstance(result, list):
return [_wrap_elements(item, ef_driver) for item in result]
- else:
- return result
class EventFiringWebDriver(object):
@@ -353,3 +353,6 @@
except Exception as e:
self._listener.on_exception(e, self._driver)
raise
+
+
+WebElement.register(EventFiringWebElement)
diff --git a/test/unit/selenium/webdriver/remote/test_subtyping.py b/test/unit/selenium/webdriver/remote/test_subtyping.py
new file mode 100644
index 0000000..5f646e1
--- /dev/null
+++ b/test/unit/selenium/webdriver/remote/test_subtyping.py
@@ -0,0 +1,55 @@
+# Licensed to the Software Freedom Conservancy (SFC) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The SFC licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from selenium.webdriver.remote.webdriver import WebElement
+from selenium.webdriver.remote.webdriver import WebDriver
+
+
+def test_web_element_not_subclassed():
+ """A registered subtype of WebElement should work with isinstance checks."""
+ class MyWebElement(object):
+ def __init__(self, parent, id, _w3c=True):
+ self.parent = parent
+ self.id = id
+ self._w3c = _w3c
+
+ # Test that non registered class instance is not instance of Remote WebElement
+ my_web_element = MyWebElement('parent', '1')
+ assert not isinstance(my_web_element, WebElement)
+
+ # Register the class as a subtype of WebElement
+ WebElement.register('MyWebElement')
+ my_registered_web_element = MyWebElement('parent', '2')
+
+ assert isinstance(my_registered_web_element, WebElement)
+
+
+def test_webdriver_not_subclassed():
+ """A registered subtype of WebDriver should work with isinstance checks."""
+ class MyWebDriver(object):
+ def __init__(self, *args, **kwargs):
+ super(MyWebDriver, self).__init__(*args, **kwargs)
+
+ # Test that non registered class instance is not instance of Remote WebDriver
+ my_driver = MyWebDriver()
+ assert not isinstance(my_driver, WebDriver)
+
+ # Register the class as a subtype of WebDriver
+ WebDriver.register(MyWebDriver)
+ my_registered_driver = MyWebDriver()
+
+ assert isinstance(my_registered_driver, MyWebDriver)