diff --git a/CHANGES.md b/CHANGES.md index 11de9ac4..aaa682a8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -19,6 +19,7 @@ * Update UnRar x64 for Windows 6.11 to 6.20 * Update Send2Trash 1.5.0 (66afce7) to 1.8.1b0 (0ef9b32) * Update SimpleJSON 3.16.1 (ce75e60) to 3.18.1 (c891b95) +* Update soupsieve 2.0.2.dev (05086ef) to 2.3.2.post1 (792d566) * Update tmdbsimple 2.6.6 (679e343) to 2.9.1 (9da400a) * Update torrent_parser 0.3.0 (2a4eecb) to 0.4.0 (23b9e11) * Update unidecode module 1.1.1 (632af82) to 1.3.6 (4141992) diff --git a/lib/soupsieve/__init__.py b/lib/soupsieve/__init__.py index a287069f..19999695 100644 --- a/lib/soupsieve/__init__.py +++ b/lib/soupsieve/__init__.py @@ -25,11 +25,14 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations from .__meta__ import __version__, __version_info__ # noqa: F401 from . import css_parser as cp from . import css_match as cm from . import css_types as ct from .util import DEBUG, SelectorSyntaxError # noqa: F401 +import bs4 # type: ignore[import] +from typing import Optional, Any, Iterator, Iterable __all__ = ( 'DEBUG', 'SelectorSyntaxError', 'SoupSieve', @@ -40,15 +43,18 @@ __all__ = ( SoupSieve = cm.SoupSieve -def compile(pattern, namespaces=None, flags=0, **kwargs): # noqa: A001 +def compile( # noqa: A001 + pattern: str, + namespaces: Optional[dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[dict[str, str]] = None, + **kwargs: Any +) -> cm.SoupSieve: """Compile CSS pattern.""" - if namespaces is not None: - namespaces = ct.Namespaces(**namespaces) - - custom = kwargs.get('custom') - if custom is not None: - custom = ct.CustomSelectors(**custom) + ns = ct.Namespaces(namespaces) if namespaces is not None else namespaces # type: Optional[ct.Namespaces] + cs = ct.CustomSelectors(custom) if custom is not None else custom # type: Optional[ct.CustomSelectors] if isinstance(pattern, SoupSieve): if flags: @@ -59,53 +65,103 @@ def compile(pattern, namespaces=None, flags=0, **kwargs): # noqa: A001 raise ValueError("Cannot process 'custom' argument on a compiled selector list") return pattern - return cp._cached_css_compile(pattern, namespaces, custom, flags) + return cp._cached_css_compile(pattern, ns, cs, flags) -def purge(): +def purge() -> None: """Purge cached patterns.""" cp._purge_cache() -def closest(select, tag, namespaces=None, flags=0, **kwargs): +def closest( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[dict[str, str]] = None, + **kwargs: Any +) -> 'bs4.Tag': """Match closest ancestor.""" return compile(select, namespaces, flags, **kwargs).closest(tag) -def match(select, tag, namespaces=None, flags=0, **kwargs): +def match( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[dict[str, str]] = None, + **kwargs: Any +) -> bool: """Match node.""" return compile(select, namespaces, flags, **kwargs).match(tag) -def filter(select, iterable, namespaces=None, flags=0, **kwargs): # noqa: A001 +def filter( # noqa: A001 + select: str, + iterable: Iterable['bs4.Tag'], + namespaces: Optional[dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[dict[str, str]] = None, + **kwargs: Any +) -> list['bs4.Tag']: """Filter list of nodes.""" return compile(select, namespaces, flags, **kwargs).filter(iterable) -def select_one(select, tag, namespaces=None, flags=0, **kwargs): +def select_one( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[dict[str, str]] = None, + **kwargs: Any +) -> 'bs4.Tag': """Select a single tag.""" return compile(select, namespaces, flags, **kwargs).select_one(tag) -def select(select, tag, namespaces=None, limit=0, flags=0, **kwargs): +def select( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[dict[str, str]] = None, + limit: int = 0, + flags: int = 0, + *, + custom: Optional[dict[str, str]] = None, + **kwargs: Any +) -> list['bs4.Tag']: """Select the specified tags.""" return compile(select, namespaces, flags, **kwargs).select(tag, limit) -def iselect(select, tag, namespaces=None, limit=0, flags=0, **kwargs): +def iselect( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[dict[str, str]] = None, + limit: int = 0, + flags: int = 0, + *, + custom: Optional[dict[str, str]] = None, + **kwargs: Any +) -> Iterator['bs4.Tag']: """Iterate the specified tags.""" for el in compile(select, namespaces, flags, **kwargs).iselect(tag, limit): yield el -def escape(ident): +def escape(ident: str) -> str: """Escape identifier.""" return cp.escape(ident) diff --git a/lib/soupsieve/__meta__.py b/lib/soupsieve/__meta__.py index ff6f8e60..34834169 100644 --- a/lib/soupsieve/__meta__.py +++ b/lib/soupsieve/__meta__.py @@ -1,4 +1,5 @@ """Meta related things.""" +from __future__ import annotations from collections import namedtuple import re @@ -79,7 +80,11 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre" """ - def __new__(cls, major, minor, micro, release="final", pre=0, post=0, dev=0): + def __new__( + cls, + major: int, minor: int, micro: int, release: str = "final", + pre: int = 0, post: int = 0, dev: int = 0 + ) -> Version: """Validate version info.""" # Ensure all parts are positive integers. @@ -115,27 +120,27 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre" return super(Version, cls).__new__(cls, major, minor, micro, release, pre, post, dev) - def _is_pre(self): + def _is_pre(self) -> bool: """Is prerelease.""" - return self.pre > 0 + return bool(self.pre > 0) - def _is_dev(self): + def _is_dev(self) -> bool: """Is development.""" return bool(self.release < "alpha") - def _is_post(self): + def _is_post(self) -> bool: """Is post.""" - return self.post > 0 + return bool(self.post > 0) - def _get_dev_status(self): # pragma: no cover + def _get_dev_status(self) -> str: # pragma: no cover """Get development status string.""" return DEV_STATUS[self.release] - def _get_canonical(self): + def _get_canonical(self) -> str: """Get the canonical output string.""" # Assemble major, minor, micro version and append `pre`, `post`, or `dev` if needed.. @@ -153,11 +158,14 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre" return ver -def parse_version(ver, pre=False): +def parse_version(ver: str) -> Version: """Parse version into a comparable Version tuple.""" m = RE_VER.match(ver) + if m is None: + raise ValueError("'{}' is not a valid version".format(ver)) + # Handle major, minor, micro major = int(m.group('major')) minor = int(m.group('minor')) if m.group('minor') else 0 @@ -185,5 +193,5 @@ def parse_version(ver, pre=False): return Version(major, minor, micro, release, pre, post, dev) -__version_info__ = Version(2, 0, 2, ".dev") +__version_info__ = Version(2, 5, 0, "final", post=1) __version__ = __version_info__._get_canonical() diff --git a/lib/soupsieve/css_match.py b/lib/soupsieve/css_match.py index 812f84b9..b06b25ee 100644 --- a/lib/soupsieve/css_match.py +++ b/lib/soupsieve/css_match.py @@ -1,11 +1,12 @@ """CSS matcher.""" +from __future__ import annotations from datetime import datetime from . import util import re -from .import css_types as ct +from . import css_types as ct import unicodedata - -import bs4 as bs4 +import bs4 # type: ignore[import] +from typing import Iterator, Iterable, Any, Optional, Callable, Sequence, cast # noqa: F401 # Empty tag pattern (whitespace okay) RE_NOT_EMPTY = re.compile('[^ \t\r\n\f]') @@ -55,7 +56,7 @@ FEB_LEAP_MONTH = 29 DAYS_IN_WEEK = 7 -class _FakeParent(object): +class _FakeParent: """ Fake parent class. @@ -64,87 +65,90 @@ class _FakeParent(object): fake parent so we can traverse the root element as a child. """ - def __init__(self, element): + def __init__(self, element: bs4.Tag) -> None: """Initialize.""" self.contents = [element] - def __len__(self): + def __len__(self) -> bs4.PageElement: """Length.""" return len(self.contents) -class _DocumentNav(object): +class _DocumentNav: """Navigate a Beautiful Soup document.""" @classmethod - def assert_valid_input(cls, tag): + def assert_valid_input(cls, tag: Any) -> None: """Check if valid input tag or document.""" # Fail on unexpected types. if not cls.is_tag(tag): - raise TypeError("Expected a BeautifulSoup 'Tag', but instead recieved type {}".format(type(tag))) + raise TypeError("Expected a BeautifulSoup 'Tag', but instead received type {}".format(type(tag))) @staticmethod - def is_doc(obj): + def is_doc(obj: bs4.Tag) -> bool: """Is `BeautifulSoup` object.""" return isinstance(obj, bs4.BeautifulSoup) @staticmethod - def is_tag(obj): + def is_tag(obj: bs4.PageElement) -> bool: """Is tag.""" return isinstance(obj, bs4.Tag) @staticmethod - def is_declaration(obj): # pragma: no cover + def is_declaration(obj: bs4.PageElement) -> bool: # pragma: no cover """Is declaration.""" return isinstance(obj, bs4.Declaration) @staticmethod - def is_cdata(obj): + def is_cdata(obj: bs4.PageElement) -> bool: """Is CDATA.""" return isinstance(obj, bs4.CData) @staticmethod - def is_processing_instruction(obj): # pragma: no cover + def is_processing_instruction(obj: bs4.PageElement) -> bool: # pragma: no cover """Is processing instruction.""" return isinstance(obj, bs4.ProcessingInstruction) @staticmethod - def is_navigable_string(obj): + def is_navigable_string(obj: bs4.PageElement) -> bool: """Is navigable string.""" return isinstance(obj, bs4.NavigableString) @staticmethod - def is_special_string(obj): + def is_special_string(obj: bs4.PageElement) -> bool: """Is special string.""" return isinstance(obj, (bs4.Comment, bs4.Declaration, bs4.CData, bs4.ProcessingInstruction, bs4.Doctype)) @classmethod - def is_content_string(cls, obj): + def is_content_string(cls, obj: bs4.PageElement) -> bool: """Check if node is content string.""" return cls.is_navigable_string(obj) and not cls.is_special_string(obj) @staticmethod - def create_fake_parent(el): + def create_fake_parent(el: bs4.Tag) -> _FakeParent: """Create fake parent for a given element.""" return _FakeParent(el) @staticmethod - def is_xml_tree(el): + def is_xml_tree(el: bs4.Tag) -> bool: """Check if element (or document) is from a XML tree.""" - return el._is_xml + return bool(el._is_xml) - def is_iframe(self, el): + def is_iframe(self, el: bs4.Tag) -> bool: """Check if element is an `iframe`.""" - return ((el.name if self.is_xml_tree(el) else util.lower(el.name)) == 'iframe') and self.is_html_tag(el) + return bool( + ((el.name if self.is_xml_tree(el) else util.lower(el.name)) == 'iframe') and + self.is_html_tag(el) # type: ignore[attr-defined] + ) - def is_root(self, el): + def is_root(self, el: bs4.Tag) -> bool: """ Return whether element is a root element. @@ -152,19 +156,26 @@ class _DocumentNav(object): and we check if it is the root element under an `iframe`. """ - root = self.root and self.root is el + root = self.root and self.root is el # type: ignore[attr-defined] if not root: parent = self.get_parent(el) - root = parent is not None and self.is_html and self.is_iframe(parent) + root = parent is not None and self.is_html and self.is_iframe(parent) # type: ignore[attr-defined] return root - def get_contents(self, el, no_iframe=False): + def get_contents(self, el: bs4.Tag, no_iframe: bool = False) -> Iterator[bs4.PageElement]: """Get contents or contents in reverse.""" if not no_iframe or not self.is_iframe(el): for content in el.contents: yield content - def get_children(self, el, start=None, reverse=False, tags=True, no_iframe=False): + def get_children( + self, + el: bs4.Tag, + start: Optional[int] = None, + reverse: bool = False, + tags: bool = True, + no_iframe: bool = False + ) -> Iterator[bs4.PageElement]: """Get children.""" if not no_iframe or not self.is_iframe(el): @@ -183,7 +194,12 @@ class _DocumentNav(object): if not tags or self.is_tag(node): yield node - def get_descendants(self, el, tags=True, no_iframe=False): + def get_descendants( + self, + el: bs4.Tag, + tags: bool = True, + no_iframe: bool = False + ) -> Iterator[bs4.PageElement]: """Get descendants.""" if not no_iframe or not self.is_iframe(el): @@ -214,7 +230,7 @@ class _DocumentNav(object): if not tags or is_tag: yield child - def get_parent(self, el, no_iframe=False): + def get_parent(self, el: bs4.Tag, no_iframe: bool = False) -> bs4.Tag: """Get parent.""" parent = el.parent @@ -223,25 +239,25 @@ class _DocumentNav(object): return parent @staticmethod - def get_tag_name(el): + def get_tag_name(el: bs4.Tag) -> Optional[str]: """Get tag.""" - return el.name + return cast(Optional[str], el.name) @staticmethod - def get_prefix_name(el): + def get_prefix_name(el: bs4.Tag) -> Optional[str]: """Get prefix.""" - return el.prefix + return cast(Optional[str], el.prefix) @staticmethod - def get_uri(el): + def get_uri(el: bs4.Tag) -> Optional[str]: """Get namespace `URI`.""" - return el.namespace + return cast(Optional[str], el.namespace) @classmethod - def get_next(cls, el, tags=True): + def get_next(cls, el: bs4.Tag, tags: bool = True) -> bs4.PageElement: """Get next sibling tag.""" sibling = el.next_sibling @@ -250,7 +266,7 @@ class _DocumentNav(object): return sibling @classmethod - def get_previous(cls, el, tags=True): + def get_previous(cls, el: bs4.Tag, tags: bool = True) -> bs4.PageElement: """Get previous sibling tag.""" sibling = el.previous_sibling @@ -259,7 +275,7 @@ class _DocumentNav(object): return sibling @staticmethod - def has_html_ns(el): + def has_html_ns(el: bs4.Tag) -> bool: """ Check if element has an HTML namespace. @@ -268,60 +284,103 @@ class _DocumentNav(object): """ ns = getattr(el, 'namespace') if el else None - return ns and ns == NS_XHTML + return bool(ns and ns == NS_XHTML) @staticmethod - def split_namespace(el, attr_name): + def split_namespace(el: bs4.Tag, attr_name: str) -> tuple[Optional[str], Optional[str]]: """Return namespace and attribute name without the prefix.""" return getattr(attr_name, 'namespace', None), getattr(attr_name, 'name', None) - @staticmethod - def get_attribute_by_name(el, name, default=None): + @classmethod + def normalize_value(cls, value: Any) -> str | Sequence[str]: + """Normalize the value to be a string or list of strings.""" + + # Treat `None` as empty string. + if value is None: + return '' + + # Pass through strings + if (isinstance(value, str)): + return value + + # If it's a byte string, convert it to Unicode, treating it as UTF-8. + if isinstance(value, bytes): + return value.decode("utf8") + + # BeautifulSoup supports sequences of attribute values, so make sure the children are strings. + if isinstance(value, Sequence): + new_value = [] + for v in value: + if not isinstance(v, (str, bytes)) and isinstance(v, Sequence): + # This is most certainly a user error and will crash and burn later. + # To keep things working, we'll do what we do with all objects, + # And convert them to strings. + new_value.append(str(v)) + else: + # Convert the child to a string + new_value.append(cast(str, cls.normalize_value(v))) + return new_value + + # Try and make anything else a string + return str(value) + + @classmethod + def get_attribute_by_name( + cls, + el: bs4.Tag, + name: str, + default: Optional[str | Sequence[str]] = None + ) -> Optional[str | Sequence[str]]: """Get attribute by name.""" value = default if el._is_xml: try: - value = el.attrs[name] + value = cls.normalize_value(el.attrs[name]) except KeyError: pass else: for k, v in el.attrs.items(): if util.lower(k) == name: - value = v + value = cls.normalize_value(v) break return value - @staticmethod - def iter_attributes(el): + @classmethod + def iter_attributes(cls, el: bs4.Tag) -> Iterator[tuple[str, Optional[str | Sequence[str]]]]: """Iterate attributes.""" for k, v in el.attrs.items(): - yield k, v + yield k, cls.normalize_value(v) @classmethod - def get_classes(cls, el): + def get_classes(cls, el: bs4.Tag) -> Sequence[str]: """Get classes.""" classes = cls.get_attribute_by_name(el, 'class', []) if isinstance(classes, str): classes = RE_NOT_WS.findall(classes) - return classes + return cast(Sequence[str], classes) - def get_text(self, el, no_iframe=False): + def get_text(self, el: bs4.Tag, no_iframe: bool = False) -> str: """Get text.""" return ''.join( [node for node in self.get_descendants(el, tags=False, no_iframe=no_iframe) if self.is_content_string(node)] ) + def get_own_text(self, el: bs4.Tag, no_iframe: bool = False) -> list[str]: + """Get Own Text.""" -class Inputs(object): + return [node for node in self.get_contents(el, no_iframe=no_iframe) if self.is_content_string(node)] + + +class Inputs: """Class for parsing and validating input items.""" @staticmethod - def validate_day(year, month, day): + def validate_day(year: int, month: int, day: int) -> bool: """Validate day.""" max_days = LONG_MONTH @@ -332,7 +391,7 @@ class Inputs(object): return 1 <= day <= max_days @staticmethod - def validate_week(year, week): + def validate_week(year: int, week: int) -> bool: """Validate week.""" max_week = datetime.strptime("{}-{}-{}".format(12, 31, year), "%m-%d-%Y").isocalendar()[1] @@ -341,34 +400,36 @@ class Inputs(object): return 1 <= week <= max_week @staticmethod - def validate_month(month): + def validate_month(month: int) -> bool: """Validate month.""" return 1 <= month <= 12 @staticmethod - def validate_year(year): + def validate_year(year: int) -> bool: """Validate year.""" return 1 <= year @staticmethod - def validate_hour(hour): + def validate_hour(hour: int) -> bool: """Validate hour.""" return 0 <= hour <= 23 @staticmethod - def validate_minutes(minutes): + def validate_minutes(minutes: int) -> bool: """Validate minutes.""" return 0 <= minutes <= 59 @classmethod - def parse_value(cls, itype, value): + def parse_value(cls, itype: str, value: Optional[str]) -> Optional[tuple[float, ...]]: """Parse the input value.""" - parsed = None + parsed = None # type: Optional[tuple[float, ...]] + if value is None: + return value if itype == "date": m = RE_DATE.match(value) if m: @@ -414,23 +475,29 @@ class Inputs(object): elif itype in ("number", "range"): m = RE_NUM.match(value) if m: - parsed = float(m.group('value')) + parsed = (float(m.group('value')),) return parsed -class _Match(object): +class CSSMatch(_DocumentNav): """Perform CSS matching.""" - def __init__(self, selectors, scope, namespaces, flags): + def __init__( + self, + selectors: ct.SelectorList, + scope: bs4.Tag, + namespaces: Optional[ct.Namespaces], + flags: int + ) -> None: """Initialize.""" self.assert_valid_input(scope) self.tag = scope - self.cached_meta_lang = [] - self.cached_default_forms = [] - self.cached_indeterminate_forms = [] + self.cached_meta_lang = [] # type: list[tuple[str, str]] + self.cached_default_forms = [] # type: list[tuple[bs4.Tag, bs4.Tag]] + self.cached_indeterminate_forms = [] # type: list[tuple[bs4.Tag, str, bool]] self.selectors = selectors - self.namespaces = {} if namespaces is None else namespaces + self.namespaces = {} if namespaces is None else namespaces # type: ct.Namespaces | dict[str, str] self.flags = flags self.iframe_restrict = False @@ -456,12 +523,12 @@ class _Match(object): self.is_xml = self.is_xml_tree(doc) self.is_html = not self.is_xml or self.has_html_namespace - def supports_namespaces(self): + def supports_namespaces(self) -> bool: """Check if namespaces are supported in the HTML type.""" return self.is_xml or self.has_html_namespace - def get_tag_ns(self, el): + def get_tag_ns(self, el: bs4.Tag) -> str: """Get tag namespace.""" if self.supports_namespaces(): @@ -473,24 +540,24 @@ class _Match(object): namespace = NS_XHTML return namespace - def is_html_tag(self, el): + def is_html_tag(self, el: bs4.Tag) -> bool: """Check if tag is in HTML namespace.""" return self.get_tag_ns(el) == NS_XHTML - def get_tag(self, el): + def get_tag(self, el: bs4.Tag) -> Optional[str]: """Get tag.""" name = self.get_tag_name(el) return util.lower(name) if name is not None and not self.is_xml else name - def get_prefix(self, el): + def get_prefix(self, el: bs4.Tag) -> Optional[str]: """Get prefix.""" prefix = self.get_prefix_name(el) return util.lower(prefix) if prefix is not None and not self.is_xml else prefix - def find_bidi(self, el): + def find_bidi(self, el: bs4.Tag) -> Optional[int]: """Get directionality from element text.""" for node in self.get_children(el, tags=False): @@ -526,7 +593,7 @@ class _Match(object): return ct.SEL_DIR_LTR if bidi == 'L' else ct.SEL_DIR_RTL return None - def extended_language_filter(self, lang_range, lang_tag): + def extended_language_filter(self, lang_range: str, lang_tag: str) -> bool: """Filter the language tags.""" match = True @@ -534,13 +601,18 @@ class _Match(object): ranges = lang_range.split('-') subtags = lang_tag.lower().split('-') length = len(ranges) + slength = len(subtags) rindex = 0 sindex = 0 r = ranges[rindex] s = subtags[sindex] + # Empty specified language should match unspecified language attributes + if length == 1 and slength == 1 and not r and r == s: + return True + # Primary tag needs to match - if r != '*' and r != s: + if (r != '*' and r != s) or (r == '*' and slength == 1 and not s): match = False rindex += 1 @@ -577,7 +649,12 @@ class _Match(object): return match - def match_attribute_name(self, el, attr, prefix): + def match_attribute_name( + self, + el: bs4.Tag, + attr: str, + prefix: Optional[str] + ) -> Optional[str | Sequence[str]]: """Match attribute name and return value if it exists.""" value = None @@ -625,13 +702,13 @@ class _Match(object): break return value - def match_namespace(self, el, tag): + def match_namespace(self, el: bs4.Tag, tag: ct.SelectorTag) -> bool: """Match the namespace of the element.""" match = True namespace = self.get_tag_ns(el) default_namespace = self.namespaces.get('') - tag_ns = '' if tag.prefix is None else self.namespaces.get(tag.prefix, None) + tag_ns = '' if tag.prefix is None else self.namespaces.get(tag.prefix) # We must match the default namespace if one is not provided if tag.prefix is None and (default_namespace is not None and namespace != default_namespace): match = False @@ -646,27 +723,26 @@ class _Match(object): match = False return match - def match_attributes(self, el, attributes): + def match_attributes(self, el: bs4.Tag, attributes: tuple[ct.SelectorAttribute, ...]) -> bool: """Match attributes.""" match = True if attributes: for a in attributes: - value = self.match_attribute_name(el, a.attribute, a.prefix) + temp = self.match_attribute_name(el, a.attribute, a.prefix) pattern = a.xml_type_pattern if self.is_xml and a.xml_type_pattern else a.pattern - if isinstance(value, list): - value = ' '.join(value) - if value is None: + if temp is None: match = False break - elif pattern is None: + value = temp if isinstance(temp, str) else ' '.join(temp) + if pattern is None: continue elif pattern.match(value) is None: match = False break return match - def match_tagname(self, el, tag): + def match_tagname(self, el: bs4.Tag, tag: ct.SelectorTag) -> bool: """Match tag name.""" name = (util.lower(tag.name) if not self.is_xml and tag.name is not None else tag.name) @@ -675,7 +751,7 @@ class _Match(object): name not in (self.get_tag(el), '*') ) - def match_tag(self, el, tag): + def match_tag(self, el: bs4.Tag, tag: Optional[ct.SelectorTag]) -> bool: """Match the tag.""" match = True @@ -687,10 +763,14 @@ class _Match(object): match = False return match - def match_past_relations(self, el, relation): + def match_past_relations(self, el: bs4.Tag, relation: ct.SelectorList) -> bool: """Match past relationship.""" found = False + # I don't think this can ever happen, but it makes `mypy` happy + if isinstance(relation[0], ct.SelectorNull): # pragma: no cover + return found + if relation[0].rel_type == REL_PARENT: parent = self.get_parent(el, no_iframe=self.iframe_restrict) while not found and parent: @@ -711,21 +791,28 @@ class _Match(object): found = self.match_selectors(sibling, relation) return found - def match_future_child(self, parent, relation, recursive=False): + def match_future_child(self, parent: bs4.Tag, relation: ct.SelectorList, recursive: bool = False) -> bool: """Match future child.""" match = False - children = self.get_descendants if recursive else self.get_children + if recursive: + children = self.get_descendants # type: Callable[..., Iterator[bs4.Tag]] + else: + children = self.get_children for child in children(parent, no_iframe=self.iframe_restrict): match = self.match_selectors(child, relation) if match: break return match - def match_future_relations(self, el, relation): + def match_future_relations(self, el: bs4.Tag, relation: ct.SelectorList) -> bool: """Match future relationship.""" found = False + # I don't think this can ever happen, but it makes `mypy` happy + if isinstance(relation[0], ct.SelectorNull): # pragma: no cover + return found + if relation[0].rel_type == REL_HAS_PARENT: found = self.match_future_child(el, relation, True) elif relation[0].rel_type == REL_HAS_CLOSE_PARENT: @@ -741,11 +828,14 @@ class _Match(object): found = self.match_selectors(sibling, relation) return found - def match_relations(self, el, relation): + def match_relations(self, el: bs4.Tag, relation: ct.SelectorList) -> bool: """Match relationship to other elements.""" found = False + if isinstance(relation[0], ct.SelectorNull) or relation[0].rel_type is None: + return found + if relation[0].rel_type.startswith(':'): found = self.match_future_relations(el, relation) else: @@ -753,7 +843,7 @@ class _Match(object): return found - def match_id(self, el, ids): + def match_id(self, el: bs4.Tag, ids: tuple[str, ...]) -> bool: """Match element's ID.""" found = True @@ -763,7 +853,7 @@ class _Match(object): break return found - def match_classes(self, el, classes): + def match_classes(self, el: bs4.Tag, classes: tuple[str, ...]) -> bool: """Match element's classes.""" current_classes = self.get_classes(el) @@ -774,7 +864,7 @@ class _Match(object): break return found - def match_root(self, el): + def match_root(self, el: bs4.Tag) -> bool: """Match element as root.""" is_root = self.is_root(el) @@ -800,20 +890,20 @@ class _Match(object): sibling = self.get_next(sibling, tags=False) return is_root - def match_scope(self, el): + def match_scope(self, el: bs4.Tag) -> bool: """Match element as scope.""" return self.scope is el - def match_nth_tag_type(self, el, child): + def match_nth_tag_type(self, el: bs4.Tag, child: bs4.Tag) -> bool: """Match tag type for `nth` matches.""" - return( + return ( (self.get_tag(child) == self.get_tag(el)) and (self.get_tag_ns(child) == self.get_tag_ns(el)) ) - def match_nth(self, el, nth): + def match_nth(self, el: bs4.Tag, nth: bs4.Tag) -> bool: """Match `nth` elements.""" matched = True @@ -914,7 +1004,7 @@ class _Match(object): break return matched - def match_empty(self, el): + def match_empty(self, el: bs4.Tag) -> bool: """Check if element is empty (if requested).""" is_empty = True @@ -927,7 +1017,7 @@ class _Match(object): break return is_empty - def match_subselectors(self, el, selectors): + def match_subselectors(self, el: bs4.Tag, selectors: tuple[ct.SelectorList, ...]) -> bool: """Match selectors.""" match = True @@ -936,24 +1026,35 @@ class _Match(object): match = False return match - def match_contains(self, el, contains): + def match_contains(self, el: bs4.Tag, contains: tuple[ct.SelectorContains, ...]) -> bool: """Match element if it contains text.""" match = True - content = None + content = None # type: Optional[str | Sequence[str]] for contain_list in contains: if content is None: - content = self.get_text(el, no_iframe=self.is_html) + if contain_list.own: + content = self.get_own_text(el, no_iframe=self.is_html) + else: + content = self.get_text(el, no_iframe=self.is_html) found = False for text in contain_list.text: - if text in content: - found = True - break + if contain_list.own: + for c in content: + if text in c: + found = True + break + if found: + break + else: + if text in content: + found = True + break if not found: match = False return match - def match_default(self, el): + def match_default(self, el: bs4.Tag) -> bool: """Match default.""" match = False @@ -986,19 +1087,19 @@ class _Match(object): if name in ('input', 'button'): v = self.get_attribute_by_name(child, 'type', '') if v and util.lower(v) == 'submit': - self.cached_default_forms.append([form, child]) + self.cached_default_forms.append((form, child)) if el is child: match = True break return match - def match_indeterminate(self, el): + def match_indeterminate(self, el: bs4.Tag) -> bool: """Match default.""" match = False - name = self.get_attribute_by_name(el, 'name') + name = cast(str, self.get_attribute_by_name(el, 'name')) - def get_parent_form(el): + def get_parent_form(el: bs4.Tag) -> Optional[bs4.Tag]: """Find this input's form.""" form = None parent = self.get_parent(el, no_iframe=True) @@ -1049,11 +1150,11 @@ class _Match(object): break if not checked: match = True - self.cached_indeterminate_forms.append([form, name, match]) + self.cached_indeterminate_forms.append((form, name, match)) return match - def match_lang(self, el, langs): + def match_lang(self, el: bs4.Tag, langs: tuple[ct.SelectorLang, ...]) -> bool: """Match languages.""" match = False @@ -1088,7 +1189,7 @@ class _Match(object): break # Use cached meta language. - if not found_lang and self.cached_meta_lang: + if found_lang is None and self.cached_meta_lang: for cache in self.cached_meta_lang: if root is cache[0]: found_lang = cache[1] @@ -1120,26 +1221,26 @@ class _Match(object): content = v if c_lang and content: found_lang = content - self.cached_meta_lang.append((root, found_lang)) + self.cached_meta_lang.append((cast(str, root), cast(str, found_lang))) break - if found_lang: + if found_lang is not None: break - if not found_lang: - self.cached_meta_lang.append((root, False)) + if found_lang is None: + self.cached_meta_lang.append((cast(str, root), '')) # If we determined a language, compare. - if found_lang: + if found_lang is not None: for patterns in langs: match = False for pattern in patterns: - if self.extended_language_filter(pattern, found_lang): + if self.extended_language_filter(pattern, cast(str, found_lang)): match = True if not match: break return match - def match_dir(self, el, directionality): + def match_dir(self, el: bs4.Tag, directionality: int) -> bool: """Check directionality.""" # If we have to match both left and right, we can't match either. @@ -1171,13 +1272,13 @@ class _Match(object): # Auto handling for text inputs if ((is_input and itype in ('text', 'search', 'tel', 'url', 'email')) or is_textarea) and direction == 0: if is_textarea: - value = [] + temp = [] for node in self.get_contents(el, no_iframe=True): if self.is_content_string(node): - value.append(node) - value = ''.join(value) + temp.append(node) + value = ''.join(temp) else: - value = self.get_attribute_by_name(el, 'value', '') + value = cast(str, self.get_attribute_by_name(el, 'value', '')) if value: for c in value: bidi = unicodedata.bidirectional(c) @@ -1202,7 +1303,7 @@ class _Match(object): # Match parents direction return self.match_dir(self.get_parent(el, no_iframe=True), directionality) - def match_range(self, el, condition): + def match_range(self, el: bs4.Tag, condition: int) -> bool: """ Match range. @@ -1215,20 +1316,14 @@ class _Match(object): out_of_range = False itype = util.lower(self.get_attribute_by_name(el, 'type')) - mn = self.get_attribute_by_name(el, 'min', None) - if mn is not None: - mn = Inputs.parse_value(itype, mn) - mx = self.get_attribute_by_name(el, 'max', None) - if mx is not None: - mx = Inputs.parse_value(itype, mx) + mn = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'min', None))) + mx = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'max', None))) # There is no valid min or max, so we cannot evaluate a range if mn is None and mx is None: return False - value = self.get_attribute_by_name(el, 'value', None) - if value is not None: - value = Inputs.parse_value(itype, value) + value = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'value', None))) if value is not None: if itype in ("date", "datetime-local", "month", "week", "number", "range"): if mn is not None and value < mn: @@ -1248,7 +1343,7 @@ class _Match(object): return not out_of_range if condition & ct.SEL_IN_RANGE else out_of_range - def match_defined(self, el): + def match_defined(self, el: bs4.Tag) -> bool: """ Match defined. @@ -1264,12 +1359,14 @@ class _Match(object): name = self.get_tag(el) return ( - name.find('-') == -1 or - name.find(':') != -1 or - self.get_prefix(el) is not None + name is not None and ( + name.find('-') == -1 or + name.find(':') != -1 or + self.get_prefix(el) is not None + ) ) - def match_placeholder_shown(self, el): + def match_placeholder_shown(self, el: bs4.Tag) -> bool: """ Match placeholder shown according to HTML spec. @@ -1284,7 +1381,7 @@ class _Match(object): return match - def match_selectors(self, el, selectors): + def match_selectors(self, el: bs4.Tag, selectors: ct.SelectorList) -> bool: """Check if element matches one of the selectors.""" match = False @@ -1356,7 +1453,7 @@ class _Match(object): if selector.flags & DIR_FLAGS and not self.match_dir(el, selector.flags & DIR_FLAGS): continue # Validate that the tag contains the specified text. - if not self.match_contains(el, selector.contains): + if selector.contains and not self.match_contains(el, selector.contains): continue match = not is_not break @@ -1368,21 +1465,20 @@ class _Match(object): return match - def select(self, limit=0): + def select(self, limit: int = 0) -> Iterator[bs4.Tag]: """Match all tags under the targeted tag.""" - if limit < 1: - limit = None + lim = None if limit < 1 else limit for child in self.get_descendants(self.tag): if self.match(child): yield child - if limit is not None: - limit -= 1 - if limit < 1: + if lim is not None: + lim -= 1 + if lim < 1: break - def closest(self): + def closest(self) -> Optional[bs4.Tag]: """Match closest ancestor.""" current = self.tag @@ -1394,30 +1490,39 @@ class _Match(object): current = self.get_parent(current) return closest - def filter(self): # noqa A001 + def filter(self) -> list[bs4.Tag]: # noqa A001 """Filter tag's children.""" return [tag for tag in self.get_contents(self.tag) if not self.is_navigable_string(tag) and self.match(tag)] - def match(self, el): + def match(self, el: bs4.Tag) -> bool: """Match.""" return not self.is_doc(el) and self.is_tag(el) and self.match_selectors(el, self.selectors) -class CSSMatch(_DocumentNav, _Match): - """The Beautiful Soup CSS match class.""" - - class SoupSieve(ct.Immutable): """Compiled Soup Sieve selector matching object.""" + pattern: str + selectors: ct.SelectorList + namespaces: Optional[ct.Namespaces] + custom: dict[str, str] + flags: int + __slots__ = ("pattern", "selectors", "namespaces", "custom", "flags", "_hash") - def __init__(self, pattern, selectors, namespaces, custom, flags): + def __init__( + self, + pattern: str, + selectors: ct.SelectorList, + namespaces: Optional[ct.Namespaces], + custom: Optional[ct.CustomSelectors], + flags: int + ): """Initialize.""" - super(SoupSieve, self).__init__( + super().__init__( pattern=pattern, selectors=selectors, namespaces=namespaces, @@ -1425,17 +1530,17 @@ class SoupSieve(ct.Immutable): flags=flags ) - def match(self, tag): + def match(self, tag: bs4.Tag) -> bool: """Match.""" return CSSMatch(self.selectors, tag, self.namespaces, self.flags).match(tag) - def closest(self, tag): + def closest(self, tag: bs4.Tag) -> bs4.Tag: """Match closest ancestor.""" return CSSMatch(self.selectors, tag, self.namespaces, self.flags).closest() - def filter(self, iterable): # noqa A001 + def filter(self, iterable: Iterable[bs4.Tag]) -> list[bs4.Tag]: # noqa A001 """ Filter. @@ -1452,24 +1557,24 @@ class SoupSieve(ct.Immutable): else: return [node for node in iterable if not CSSMatch.is_navigable_string(node) and self.match(node)] - def select_one(self, tag): + def select_one(self, tag: bs4.Tag) -> bs4.Tag: """Select a single tag.""" tags = self.select(tag, limit=1) return tags[0] if tags else None - def select(self, tag, limit=0): + def select(self, tag: bs4.Tag, limit: int = 0) -> list[bs4.Tag]: """Select the specified tags.""" return list(self.iselect(tag, limit)) - def iselect(self, tag, limit=0): + def iselect(self, tag: bs4.Tag, limit: int = 0) -> Iterator[bs4.Tag]: """Iterate the specified tags.""" for el in CSSMatch(self.selectors, tag, self.namespaces, self.flags).select(limit): yield el - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover """Representation.""" return "SoupSieve(pattern={!r}, namespaces={!r}, custom={!r}, flags={!r})".format( diff --git a/lib/soupsieve/css_parser.py b/lib/soupsieve/css_parser.py index 4fab9bab..3cd3e731 100644 --- a/lib/soupsieve/css_parser.py +++ b/lib/soupsieve/css_parser.py @@ -1,10 +1,13 @@ """CSS selector parser.""" +from __future__ import annotations import re from functools import lru_cache from . import util from . import css_match as cm from . import css_types as ct from .util import SelectorSyntaxError +import warnings +from typing import Optional, Match, Any, Iterator, cast UNICODE_REPLACEMENT_CHAR = 0xFFFD @@ -59,6 +62,8 @@ PSEUDO_SIMPLE_NO_MATCH = { # Complex pseudo classes that take selector lists PSEUDO_COMPLEX = { ':contains', + ':-soup-contains', + ':-soup-contains-own', ':has', ':is', ':matches', @@ -193,32 +198,42 @@ FLG_OPEN = 0x40 FLG_IN_RANGE = 0x80 FLG_OUT_OF_RANGE = 0x100 FLG_PLACEHOLDER_SHOWN = 0x200 +FLG_FORGIVE = 0x400 # Maximum cached patterns to store _MAXCACHE = 500 @lru_cache(maxsize=_MAXCACHE) -def _cached_css_compile(pattern, namespaces, custom, flags): +def _cached_css_compile( + pattern: str, + namespaces: Optional[ct.Namespaces], + custom: Optional[ct.CustomSelectors], + flags: int +) -> cm.SoupSieve: """Cached CSS compile.""" custom_selectors = process_custom(custom) return cm.SoupSieve( pattern, - CSSParser(pattern, custom=custom_selectors, flags=flags).process_selectors(), + CSSParser( + pattern, + custom=custom_selectors, + flags=flags + ).process_selectors(), namespaces, custom, flags ) -def _purge_cache(): +def _purge_cache() -> None: """Purge the cache.""" _cached_css_compile.cache_clear() -def process_custom(custom): +def process_custom(custom: Optional[ct.CustomSelectors]) -> dict[str, str | ct.SelectorList]: """Process custom.""" custom_selectors = {} @@ -233,14 +248,14 @@ def process_custom(custom): return custom_selectors -def css_unescape(content, string=False): +def css_unescape(content: str, string: bool = False) -> str: """ Unescape CSS value. Strings allow for spanning the value on multiple strings by escaping a new line. """ - def replace(m): + def replace(m: Match[str]) -> str: """Replace with the appropriate substitute.""" if m.group(1): @@ -260,7 +275,7 @@ def css_unescape(content, string=False): return (RE_CSS_ESC if not string else RE_CSS_STR_ESC).sub(replace, content) -def escape(ident): +def escape(ident: str) -> str: """Escape identifier.""" string = [] @@ -288,21 +303,21 @@ def escape(ident): return ''.join(string) -class SelectorPattern(object): +class SelectorPattern: """Selector pattern.""" - def __init__(self, name, pattern): + def __init__(self, name: str, pattern: str) -> None: """Initialize.""" self.name = name self.re_pattern = re.compile(pattern, re.I | re.X | re.U) - def get_name(self): + def get_name(self) -> str: """Get name.""" return self.name - def match(self, selector, index, flags): + def match(self, selector: str, index: int, flags: int) -> Optional[Match[str]]: """Match the selector.""" return self.re_pattern.match(selector, index) @@ -311,7 +326,7 @@ class SelectorPattern(object): class SpecialPseudoPattern(SelectorPattern): """Selector pattern.""" - def __init__(self, patterns): + def __init__(self, patterns: tuple[tuple[str, tuple[str, ...], str, type[SelectorPattern]], ...]) -> None: """Initialize.""" self.patterns = {} @@ -321,15 +336,15 @@ class SpecialPseudoPattern(SelectorPattern): for pseudo in p[1]: self.patterns[pseudo] = pattern - self.matched_name = None + self.matched_name = None # type: Optional[SelectorPattern] self.re_pseudo_name = re.compile(PAT_PSEUDO_CLASS_SPECIAL, re.I | re.X | re.U) - def get_name(self): + def get_name(self) -> str: """Get name.""" - return self.matched_name.get_name() + return '' if self.matched_name is None else self.matched_name.get_name() - def match(self, selector, index, flags): + def match(self, selector: str, index: int, flags: int) -> Optional[Match[str]]: """Match the selector.""" pseudo = None @@ -345,7 +360,7 @@ class SpecialPseudoPattern(SelectorPattern): return pseudo -class _Selector(object): +class _Selector: """ Intermediate selector class. @@ -354,23 +369,23 @@ class _Selector(object): the data in an object that can be pickled and hashed. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """Initialize.""" - self.tag = kwargs.get('tag', None) - self.ids = kwargs.get('ids', []) - self.classes = kwargs.get('classes', []) - self.attributes = kwargs.get('attributes', []) - self.nth = kwargs.get('nth', []) - self.selectors = kwargs.get('selectors', []) - self.relations = kwargs.get('relations', []) - self.rel_type = kwargs.get('rel_type', None) - self.contains = kwargs.get('contains', []) - self.lang = kwargs.get('lang', []) - self.flags = kwargs.get('flags', 0) - self.no_match = kwargs.get('no_match', False) + self.tag = kwargs.get('tag', None) # type: Optional[ct.SelectorTag] + self.ids = kwargs.get('ids', []) # type: list[str] + self.classes = kwargs.get('classes', []) # type: list[str] + self.attributes = kwargs.get('attributes', []) # type: list[ct.SelectorAttribute] + self.nth = kwargs.get('nth', []) # type: list[ct.SelectorNth] + self.selectors = kwargs.get('selectors', []) # type: list[ct.SelectorList] + self.relations = kwargs.get('relations', []) # type: list[_Selector] + self.rel_type = kwargs.get('rel_type', None) # type: Optional[str] + self.contains = kwargs.get('contains', []) # type: list[ct.SelectorContains] + self.lang = kwargs.get('lang', []) # type: list[ct.SelectorLang] + self.flags = kwargs.get('flags', 0) # type: int + self.no_match = kwargs.get('no_match', False) # type: bool - def _freeze_relations(self, relations): + def _freeze_relations(self, relations: list[_Selector]) -> ct.SelectorList: """Freeze relation.""" if relations: @@ -380,7 +395,7 @@ class _Selector(object): else: return ct.SelectorList() - def freeze(self): + def freeze(self) -> ct.Selector | ct.SelectorNull: """Freeze self.""" if self.no_match: @@ -400,7 +415,7 @@ class _Selector(object): self.flags ) - def __str__(self): # pragma: no cover + def __str__(self) -> str: # pragma: no cover """String representation.""" return ( @@ -414,14 +429,19 @@ class _Selector(object): __repr__ = __str__ -class CSSParser(object): +class CSSParser: """Parse CSS selectors.""" css_tokens = ( SelectorPattern("pseudo_close", PAT_PSEUDO_CLOSE), SpecialPseudoPattern( ( - ("pseudo_contains", (':contains',), PAT_PSEUDO_CONTAINS, SelectorPattern), + ( + "pseudo_contains", + (':contains', ':-soup-contains', ':-soup-contains-own'), + PAT_PSEUDO_CONTAINS, + SelectorPattern + ), ("pseudo_nth_child", (':nth-child', ':nth-last-child'), PAT_PSEUDO_NTH_CHILD, SelectorPattern), ("pseudo_nth_type", (':nth-of-type', ':nth-last-of-type'), PAT_PSEUDO_NTH_TYPE, SelectorPattern), ("pseudo_lang", (':lang',), PAT_PSEUDO_LANG, SelectorPattern), @@ -439,7 +459,12 @@ class CSSParser(object): SelectorPattern("combine", PAT_COMBINE) ) - def __init__(self, selector, custom=None, flags=0): + def __init__( + self, + selector: str, + custom: Optional[dict[str, str | ct.SelectorList]] = None, + flags: int = 0 + ) -> None: """Initialize.""" self.pattern = selector.replace('\x00', '\ufffd') @@ -447,7 +472,7 @@ class CSSParser(object): self.debug = self.flags & util.DEBUG self.custom = {} if custom is None else custom - def parse_attribute_selector(self, sel, m, has_selector): + def parse_attribute_selector(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Create attribute selector from the returned regex match.""" inverse = False @@ -457,22 +482,22 @@ class CSSParser(object): attr = css_unescape(m.group('attr_name')) is_type = False pattern2 = None + value = '' if case: - flags = re.I if case == 'i' else 0 + flags = (re.I if case == 'i' else 0) | re.DOTALL elif util.lower(attr) == 'type': - flags = re.I + flags = re.I | re.DOTALL is_type = True else: - flags = 0 + flags = re.DOTALL if op: if m.group('value').startswith(('"', "'")): value = css_unescape(m.group('value')[1:-1], True) else: value = css_unescape(m.group('value')) - else: - value = None + if not op: # Attribute name pattern = None @@ -517,7 +542,7 @@ class CSSParser(object): has_selector = True return has_selector - def parse_tag_pattern(self, sel, m, has_selector): + def parse_tag_pattern(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse tag pattern from regex match.""" prefix = css_unescape(m.group('tag_ns')[:-1]) if m.group('tag_ns') else None @@ -526,7 +551,7 @@ class CSSParser(object): has_selector = True return has_selector - def parse_pseudo_class_custom(self, sel, m, has_selector): + def parse_pseudo_class_custom(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """ Parse custom pseudo class alias. @@ -538,13 +563,13 @@ class CSSParser(object): selector = self.custom.get(pseudo) if selector is None: raise SelectorSyntaxError( - "Undefined custom selector '{}' found at postion {}".format(pseudo, m.end(0)), + "Undefined custom selector '{}' found at position {}".format(pseudo, m.end(0)), self.pattern, m.end(0) ) if not isinstance(selector, ct.SelectorList): - self.custom[pseudo] = None + del self.custom[pseudo] selector = CSSParser( selector, custom=self.custom, flags=self.flags ).process_selectors(flags=FLG_PSEUDO) @@ -554,7 +579,14 @@ class CSSParser(object): has_selector = True return has_selector - def parse_pseudo_class(self, sel, m, has_selector, iselector, is_html): + def parse_pseudo_class( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + iselector: Iterator[tuple[str, Match[str]]], + is_html: bool + ) -> tuple[bool, bool]: """Parse pseudo class.""" complex_pseudo = False @@ -642,7 +674,13 @@ class CSSParser(object): return has_selector, is_html - def parse_pseudo_nth(self, sel, m, has_selector, iselector): + def parse_pseudo_nth( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + iselector: Iterator[tuple[str, Match[str]]] + ) -> bool: """Parse `nth` pseudo.""" mdict = m.groupdict() @@ -663,29 +701,29 @@ class CSSParser(object): s2 = 1 var = True else: - nth_parts = RE_NTH.match(content) - s1 = '-' if nth_parts.group('s1') and nth_parts.group('s1') == '-' else '' + nth_parts = cast(Match[str], RE_NTH.match(content)) + _s1 = '-' if nth_parts.group('s1') and nth_parts.group('s1') == '-' else '' a = nth_parts.group('a') var = a.endswith('n') if a.startswith('n'): - s1 += '1' + _s1 += '1' elif var: - s1 += a[:-1] + _s1 += a[:-1] else: - s1 += a - s2 = '-' if nth_parts.group('s2') and nth_parts.group('s2') == '-' else '' + _s1 += a + _s2 = '-' if nth_parts.group('s2') and nth_parts.group('s2') == '-' else '' if nth_parts.group('b'): - s2 += nth_parts.group('b') + _s2 += nth_parts.group('b') else: - s2 = '0' - s1 = int(s1, 10) - s2 = int(s2, 10) + _s2 = '0' + s1 = int(_s1, 10) + s2 = int(_s2, 10) pseudo_sel = mdict['name'] if postfix == '_child': if m.group('of'): # Parse the rest of `of S`. - nth_sel = self.parse_selectors(iselector, m.end(0), FLG_PSEUDO | FLG_OPEN) + nth_sel = self.parse_selectors(iselector, m.end(0), FLG_PSEUDO | FLG_OPEN | FLG_FORGIVE) else: # Use default `*|*` for `of S`. nth_sel = CSS_NTH_OF_S_DEFAULT @@ -701,20 +739,38 @@ class CSSParser(object): has_selector = True return has_selector - def parse_pseudo_open(self, sel, name, has_selector, iselector, index): + def parse_pseudo_open( + self, + sel: _Selector, + name: str, + has_selector: bool, + iselector: Iterator[tuple[str, Match[str]]], + index: int + ) -> bool: """Parse pseudo with opening bracket.""" flags = FLG_PSEUDO | FLG_OPEN if name == ':not': flags |= FLG_NOT - if name == ':has': - flags |= FLG_RELATIVE + elif name == ':has': + flags |= FLG_RELATIVE | FLG_FORGIVE + elif name in (':where', ':is'): + flags |= FLG_FORGIVE sel.selectors.append(self.parse_selectors(iselector, index, flags)) has_selector = True + return has_selector - def parse_has_combinator(self, sel, m, has_selector, selectors, rel_type, index): + def parse_has_combinator( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + selectors: list[_Selector], + rel_type: str, + index: int + ) -> tuple[bool, _Selector, str]: """Parse combinator tokens.""" combinator = m.group('relation').strip() @@ -723,12 +779,9 @@ class CSSParser(object): if combinator == COMMA_COMBINATOR: if not has_selector: # If we've not captured any selector parts, the comma is either at the beginning of the pattern - # or following another comma, both of which are unexpected. Commas must split selectors. - raise SelectorSyntaxError( - "The combinator '{}' at postion {}, must have a selector before it".format(combinator, index), - self.pattern, - index - ) + # or following another comma, both of which are unexpected. But shouldn't fail the pseudo-class. + sel.no_match = True + sel.rel_type = rel_type selectors[-1].relations.append(sel) rel_type = ":" + WS_COMBINATOR @@ -749,44 +802,63 @@ class CSSParser(object): self.pattern, index ) + # Set the leading combinator for the next selector. rel_type = ':' + combinator - sel = _Selector() + sel = _Selector() has_selector = False return has_selector, sel, rel_type - def parse_combinator(self, sel, m, has_selector, selectors, relations, is_pseudo, index): + def parse_combinator( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + selectors: list[_Selector], + relations: list[_Selector], + is_pseudo: bool, + is_forgive: bool, + index: int + ) -> tuple[bool, _Selector]: """Parse combinator tokens.""" combinator = m.group('relation').strip() if not combinator: combinator = WS_COMBINATOR if not has_selector: - raise SelectorSyntaxError( - "The combinator '{}' at postion {}, must have a selector before it".format(combinator, index), - self.pattern, - index - ) + if not is_forgive or combinator != COMMA_COMBINATOR: + raise SelectorSyntaxError( + "The combinator '{}' at position {}, must have a selector before it".format(combinator, index), + self.pattern, + index + ) - if combinator == COMMA_COMBINATOR: - if not sel.tag and not is_pseudo: - # Implied `*` - sel.tag = ct.SelectorTag('*', None) - sel.relations.extend(relations) - selectors.append(sel) - del relations[:] + # If we are in a forgiving pseudo class, just make the selector a "no match" + if combinator == COMMA_COMBINATOR: + sel.no_match = True + del relations[:] + selectors.append(sel) else: - sel.relations.extend(relations) - sel.rel_type = combinator - del relations[:] - relations.append(sel) - sel = _Selector() + if combinator == COMMA_COMBINATOR: + if not sel.tag and not is_pseudo: + # Implied `*` + sel.tag = ct.SelectorTag('*', None) + sel.relations.extend(relations) + selectors.append(sel) + del relations[:] + else: + sel.relations.extend(relations) + sel.rel_type = combinator + del relations[:] + relations.append(sel) + sel = _Selector() has_selector = False + return has_selector, sel - def parse_class_id(self, sel, m, has_selector): + def parse_class_id(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse HTML classes and ids.""" selector = m.group(0) @@ -797,10 +869,17 @@ class CSSParser(object): has_selector = True return has_selector - def parse_pseudo_contains(self, sel, m, has_selector): + def parse_pseudo_contains(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse contains.""" - values = m.group('values') + pseudo = util.lower(css_unescape(m.group('name'))) + if pseudo == ":contains": + warnings.warn( + "The pseudo class ':contains' is deprecated, ':-soup-contains' should be used moving forward.", + FutureWarning + ) + contains_own = pseudo == ":-soup-contains-own" + values = css_unescape(m.group('values')) patterns = [] for token in RE_VALUES.finditer(values): if token.group('split'): @@ -811,11 +890,11 @@ class CSSParser(object): else: value = css_unescape(value) patterns.append(value) - sel.contains.append(ct.SelectorContains(tuple(patterns))) + sel.contains.append(ct.SelectorContains(patterns, contains_own)) has_selector = True return has_selector - def parse_pseudo_lang(self, sel, m, has_selector): + def parse_pseudo_lang(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse pseudo language.""" values = m.group('values') @@ -836,7 +915,7 @@ class CSSParser(object): return has_selector - def parse_pseudo_dir(self, sel, m, has_selector): + def parse_pseudo_dir(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse pseudo direction.""" value = ct.SEL_DIR_LTR if util.lower(m.group('dir')) == 'ltr' else ct.SEL_DIR_RTL @@ -844,15 +923,23 @@ class CSSParser(object): has_selector = True return has_selector - def parse_selectors(self, iselector, index=0, flags=0): + def parse_selectors( + self, + iselector: Iterator[tuple[str, Match[str]]], + index: int = 0, + flags: int = 0 + ) -> ct.SelectorList: """Parse selectors.""" + # Initialize important variables sel = _Selector() selectors = [] has_selector = False closed = False - relations = [] + relations = [] # type: list[_Selector] rel_type = ":" + WS_COMBINATOR + + # Setup various flags is_open = bool(flags & FLG_OPEN) is_pseudo = bool(flags & FLG_PSEUDO) is_relative = bool(flags & FLG_RELATIVE) @@ -863,7 +950,9 @@ class CSSParser(object): is_in_range = bool(flags & FLG_IN_RANGE) is_out_of_range = bool(flags & FLG_OUT_OF_RANGE) is_placeholder_shown = bool(flags & FLG_PLACEHOLDER_SHOWN) + is_forgive = bool(flags & FLG_FORGIVE) + # Print out useful debug stuff if self.debug: # pragma: no cover if is_pseudo: print(' is_pseudo: True') @@ -885,7 +974,10 @@ class CSSParser(object): print(' is_out_of_range: True') if is_placeholder_shown: print(' is_placeholder_shown: True') + if is_forgive: + print(' is_forgive: True') + # The algorithm for relative selectors require an initial selector in the selector list if is_relative: selectors.append(_Selector()) @@ -914,17 +1006,19 @@ class CSSParser(object): is_html = True elif key == 'pseudo_close': if not has_selector: - raise SelectorSyntaxError( - "Expected a selector at postion {}".format(m.start(0)), - self.pattern, - m.start(0) - ) + if not is_forgive: + raise SelectorSyntaxError( + "Expected a selector at position {}".format(m.start(0)), + self.pattern, + m.start(0) + ) + sel.no_match = True if is_open: closed = True break else: raise SelectorSyntaxError( - "Unmatched pseudo-class close at postion {}".format(m.start(0)), + "Unmatched pseudo-class close at position {}".format(m.start(0)), self.pattern, m.start(0) ) @@ -935,7 +1029,7 @@ class CSSParser(object): ) else: has_selector, sel = self.parse_combinator( - sel, m, has_selector, selectors, relations, is_pseudo, index + sel, m, has_selector, selectors, relations, is_pseudo, is_forgive, index ) elif key == 'attribute': has_selector = self.parse_attribute_selector(sel, m, has_selector) @@ -954,6 +1048,7 @@ class CSSParser(object): except StopIteration: pass + # Handle selectors that are not closed if is_open and not closed: raise SelectorSyntaxError( "Unclosed pseudo-class at position {}".format(index), @@ -961,6 +1056,7 @@ class CSSParser(object): index ) + # Cleanup completed selector piece if has_selector: if not sel.tag and not is_pseudo: # Implied `*` @@ -972,8 +1068,28 @@ class CSSParser(object): sel.relations.extend(relations) del relations[:] selectors.append(sel) - else: + + # Forgive empty slots in pseudo-classes that have lists (and are forgiving) + elif is_forgive: + if is_relative: + # Handle relative selectors pseudo-classes with empty slots like `:has()` + if selectors and selectors[-1].rel_type is None and rel_type == ': ': + sel.rel_type = rel_type + sel.no_match = True + selectors[-1].relations.append(sel) + has_selector = True + else: + # Handle normal pseudo-classes with empty slots + if not selectors or not relations: + # Others like `:is()` etc. + sel.no_match = True + del relations[:] + selectors.append(sel) + has_selector = True + + if not has_selector: # We will always need to finish a selector when `:has()` is used as it leads with combining. + # May apply to others as well. raise SelectorSyntaxError( 'Expected a selector at position {}'.format(index), self.pattern, @@ -994,9 +1110,10 @@ class CSSParser(object): if is_placeholder_shown: selectors[-1].flags = ct.SEL_PLACEHOLDER_SHOWN + # Return selector list return ct.SelectorList([s.freeze() for s in selectors], is_not, is_html) - def selector_iter(self, pattern): + def selector_iter(self, pattern: str) -> Iterator[tuple[str, Match[str]]]: """Iterate selector tokens.""" # Ignore whitespace and comments at start and end of pattern @@ -1037,7 +1154,7 @@ class CSSParser(object): if self.debug: # pragma: no cover print('## END PARSING') - def process_selectors(self, index=0, flags=0): + def process_selectors(self, index: int = 0, flags: int = 0) -> ct.SelectorList: """Process selectors.""" return self.parse_selectors(self.selector_iter(self.pattern), index, flags) @@ -1048,7 +1165,7 @@ class CSSParser(object): # CSS pattern for `:link` and `:any-link` CSS_LINK = CSSParser( - 'html|*:is(a, area, link)[href]' + 'html|*:is(a, area)[href]' ).process_selectors(flags=FLG_PSEUDO | FLG_HTML) # CSS pattern for `:checked` CSS_CHECKED = CSSParser( @@ -1079,23 +1196,23 @@ CSS_INDETERMINATE = CSSParser( This pattern must be at the end. Special logic is applied to the last selector. */ - html|input[type="radio"][name][name!='']:not([checked]) + html|input[type="radio"][name]:not([name='']):not([checked]) ''' ).process_selectors(flags=FLG_PSEUDO | FLG_HTML | FLG_INDETERMINATE) # CSS pattern for `:disabled` CSS_DISABLED = CSSParser( ''' - html|*:is(input[type!=hidden], button, select, textarea, fieldset, optgroup, option, fieldset)[disabled], + html|*:is(input:not([type=hidden]), button, select, textarea, fieldset, optgroup, option, fieldset)[disabled], html|optgroup[disabled] > html|option, - html|fieldset[disabled] > html|*:is(input[type!=hidden], button, select, textarea, fieldset), + html|fieldset[disabled] > html|*:is(input:not([type=hidden]), button, select, textarea, fieldset), html|fieldset[disabled] > - html|*:not(legend:nth-of-type(1)) html|*:is(input[type!=hidden], button, select, textarea, fieldset) + html|*:not(legend:nth-of-type(1)) html|*:is(input:not([type=hidden]), button, select, textarea, fieldset) ''' ).process_selectors(flags=FLG_PSEUDO | FLG_HTML) # CSS pattern for `:enabled` CSS_ENABLED = CSSParser( ''' - html|*:is(input[type!=hidden], button, select, textarea, fieldset, optgroup, option, fieldset):not(:disabled) + html|*:is(input:not([type=hidden]), button, select, textarea, fieldset, optgroup, option, fieldset):not(:disabled) ''' ).process_selectors(flags=FLG_PSEUDO | FLG_HTML) # CSS pattern for `:required` @@ -1119,8 +1236,8 @@ CSS_PLACEHOLDER_SHOWN = CSSParser( [type=email], [type=password], [type=number] - )[placeholder][placeholder!='']:is(:not([value]), [value=""]), - html|textarea[placeholder][placeholder!=''] + )[placeholder]:not([placeholder='']):is(:not([value]), [value=""]), + html|textarea[placeholder]:not([placeholder='']) ''' ).process_selectors(flags=FLG_PSEUDO | FLG_HTML | FLG_PLACEHOLDER_SHOWN) # CSS pattern default for `:nth-child` "of S" feature diff --git a/lib/soupsieve/css_types.py b/lib/soupsieve/css_types.py index 3274a3ab..fb375216 100644 --- a/lib/soupsieve/css_types.py +++ b/lib/soupsieve/css_types.py @@ -1,6 +1,8 @@ """CSS selector structure items.""" +from __future__ import annotations import copyreg -from collections.abc import Hashable, Mapping +from .pretty import pretty +from typing import Any, Iterator, Hashable, Optional, Pattern, Iterable, Mapping __all__ = ( 'Selector', @@ -29,12 +31,14 @@ SEL_DEFINED = 0x200 SEL_PLACEHOLDER_SHOWN = 0x400 -class Immutable(object): +class Immutable: """Immutable.""" - __slots__ = ('_hash',) + __slots__: tuple[str, ...] = ('_hash',) - def __init__(self, **kwargs): + _hash: int + + def __init__(self, **kwargs: Any) -> None: """Initialize.""" temp = [] @@ -45,12 +49,12 @@ class Immutable(object): super(Immutable, self).__setattr__('_hash', hash(tuple(temp))) @classmethod - def __base__(cls): + def __base__(cls) -> "type[Immutable]": """Get base class.""" return cls - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Equal.""" return ( @@ -58,7 +62,7 @@ class Immutable(object): all([getattr(other, key) == getattr(self, key) for key in self.__slots__ if key != '_hash']) ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Equal.""" return ( @@ -66,63 +70,74 @@ class Immutable(object): any([getattr(other, key) != getattr(self, key) for key in self.__slots__ if key != '_hash']) ) - def __hash__(self): + def __hash__(self) -> int: """Hash.""" return self._hash - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: """Prevent mutability.""" raise AttributeError("'{}' is immutable".format(self.__class__.__name__)) - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover """Representation.""" return "{}({})".format( - self.__base__(), ', '.join(["{}={!r}".format(k, getattr(self, k)) for k in self.__slots__[:-1]]) + self.__class__.__name__, ', '.join(["{}={!r}".format(k, getattr(self, k)) for k in self.__slots__[:-1]]) ) __str__ = __repr__ + def pretty(self) -> None: # pragma: no cover + """Pretty print.""" -class ImmutableDict(Mapping): + print(pretty(self)) + + +class ImmutableDict(Mapping[Any, Any]): """Hashable, immutable dictionary.""" - def __init__(self, *args, **kwargs): + def __init__( + self, + arg: dict[Any, Any] | Iterable[tuple[Any, Any]] + ) -> None: """Initialize.""" - arg = args[0] if args else kwargs - is_dict = isinstance(arg, dict) - if ( - is_dict and not all([isinstance(v, Hashable) for v in arg.values()]) or - not is_dict and not all([isinstance(k, Hashable) and isinstance(v, Hashable) for k, v in arg]) - ): - raise TypeError('All values must be hashable') - - self._d = dict(*args, **kwargs) + self._validate(arg) + self._d = dict(arg) self._hash = hash(tuple([(type(x), x, type(y), y) for x, y in sorted(self._d.items())])) - def __iter__(self): + def _validate(self, arg: dict[Any, Any] | Iterable[tuple[Any, Any]]) -> None: + """Validate arguments.""" + + if isinstance(arg, dict): + if not all([isinstance(v, Hashable) for v in arg.values()]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + elif not all([isinstance(k, Hashable) and isinstance(v, Hashable) for k, v in arg]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + + def __iter__(self) -> Iterator[Any]: """Iterator.""" return iter(self._d) - def __len__(self): + def __len__(self) -> int: """Length.""" return len(self._d) - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: """Get item: `namespace['key']`.""" + return self._d[key] - def __hash__(self): + def __hash__(self) -> int: """Hash.""" return self._hash - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover """Representation.""" return "{!r}".format(self._d) @@ -133,39 +148,37 @@ class ImmutableDict(Mapping): class Namespaces(ImmutableDict): """Namespaces.""" - def __init__(self, *args, **kwargs): + def __init__(self, arg: dict[str, str] | Iterable[tuple[str, str]]) -> None: """Initialize.""" - # If there are arguments, check the first index. - # `super` should fail if the user gave multiple arguments, - # so don't bother checking that. - arg = args[0] if args else kwargs - is_dict = isinstance(arg, dict) - if is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg.items()]): - raise TypeError('Namespace keys and values must be Unicode strings') - elif not is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): - raise TypeError('Namespace keys and values must be Unicode strings') + super().__init__(arg) - super(Namespaces, self).__init__(*args, **kwargs) + def _validate(self, arg: dict[str, str] | Iterable[tuple[str, str]]) -> None: + """Validate arguments.""" + + if isinstance(arg, dict): + if not all([isinstance(v, str) for v in arg.values()]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + elif not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): + raise TypeError('{} keys and values must be Unicode strings'.format(self.__class__.__name__)) class CustomSelectors(ImmutableDict): """Custom selectors.""" - def __init__(self, *args, **kwargs): + def __init__(self, arg: dict[str, str] | Iterable[tuple[str, str]]) -> None: """Initialize.""" - # If there are arguments, check the first index. - # `super` should fail if the user gave multiple arguments, - # so don't bother checking that. - arg = args[0] if args else kwargs - is_dict = isinstance(arg, dict) - if is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg.items()]): - raise TypeError('CustomSelectors keys and values must be Unicode strings') - elif not is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): - raise TypeError('CustomSelectors keys and values must be Unicode strings') + super().__init__(arg) - super(CustomSelectors, self).__init__(*args, **kwargs) + def _validate(self, arg: dict[str, str] | Iterable[tuple[str, str]]) -> None: + """Validate arguments.""" + + if isinstance(arg, dict): + if not all([isinstance(v, str) for v in arg.values()]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + elif not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): + raise TypeError('{} keys and values must be Unicode strings'.format(self.__class__.__name__)) class Selector(Immutable): @@ -176,13 +189,35 @@ class Selector(Immutable): 'relation', 'rel_type', 'contains', 'lang', 'flags', '_hash' ) + tag: Optional[SelectorTag] + ids: tuple[str, ...] + classes: tuple[str, ...] + attributes: tuple[SelectorAttribute, ...] + nth: tuple[SelectorNth, ...] + selectors: tuple[SelectorList, ...] + relation: SelectorList + rel_type: Optional[str] + contains: tuple[SelectorContains, ...] + lang: tuple[SelectorLang, ...] + flags: int + def __init__( - self, tag, ids, classes, attributes, nth, selectors, - relation, rel_type, contains, lang, flags + self, + tag: Optional[SelectorTag], + ids: tuple[str, ...], + classes: tuple[str, ...], + attributes: tuple[SelectorAttribute, ...], + nth: tuple[SelectorNth, ...], + selectors: tuple[SelectorList, ...], + relation: SelectorList, + rel_type: Optional[str], + contains: tuple[SelectorContains, ...], + lang: tuple[SelectorLang, ...], + flags: int ): """Initialize.""" - super(Selector, self).__init__( + super().__init__( tag=tag, ids=ids, classes=classes, @@ -200,10 +235,10 @@ class Selector(Immutable): class SelectorNull(Immutable): """Null Selector.""" - def __init__(self): + def __init__(self) -> None: """Initialize.""" - super(SelectorNull, self).__init__() + super().__init__() class SelectorTag(Immutable): @@ -211,13 +246,13 @@ class SelectorTag(Immutable): __slots__ = ("name", "prefix", "_hash") - def __init__(self, name, prefix): + name: str + prefix: Optional[str] + + def __init__(self, name: str, prefix: Optional[str]) -> None: """Initialize.""" - super(SelectorTag, self).__init__( - name=name, - prefix=prefix - ) + super().__init__(name=name, prefix=prefix) class SelectorAttribute(Immutable): @@ -225,10 +260,21 @@ class SelectorAttribute(Immutable): __slots__ = ("attribute", "prefix", "pattern", "xml_type_pattern", "_hash") - def __init__(self, attribute, prefix, pattern, xml_type_pattern): + attribute: str + prefix: str + pattern: Optional[Pattern[str]] + xml_type_pattern: Optional[Pattern[str]] + + def __init__( + self, + attribute: str, + prefix: str, + pattern: Optional[Pattern[str]], + xml_type_pattern: Optional[Pattern[str]] + ) -> None: """Initialize.""" - super(SelectorAttribute, self).__init__( + super().__init__( attribute=attribute, prefix=prefix, pattern=pattern, @@ -239,14 +285,15 @@ class SelectorAttribute(Immutable): class SelectorContains(Immutable): """Selector contains rule.""" - __slots__ = ("text", "_hash") + __slots__ = ("text", "own", "_hash") - def __init__(self, text): + text: tuple[str, ...] + own: bool + + def __init__(self, text: Iterable[str], own: bool) -> None: """Initialize.""" - super(SelectorContains, self).__init__( - text=text - ) + super().__init__(text=tuple(text), own=own) class SelectorNth(Immutable): @@ -254,10 +301,17 @@ class SelectorNth(Immutable): __slots__ = ("a", "n", "b", "of_type", "last", "selectors", "_hash") - def __init__(self, a, n, b, of_type, last, selectors): + a: int + n: bool + b: int + of_type: bool + last: bool + selectors: SelectorList + + def __init__(self, a: int, n: bool, b: int, of_type: bool, last: bool, selectors: SelectorList) -> None: """Initialize.""" - super(SelectorNth, self).__init__( + super().__init__( a=a, n=n, b=b, @@ -272,24 +326,24 @@ class SelectorLang(Immutable): __slots__ = ("languages", "_hash",) - def __init__(self, languages): + languages: tuple[str, ...] + + def __init__(self, languages: Iterable[str]): """Initialize.""" - super(SelectorLang, self).__init__( - languages=tuple(languages) - ) + super().__init__(languages=tuple(languages)) - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Iterator.""" return iter(self.languages) - def __len__(self): # pragma: no cover + def __len__(self) -> int: # pragma: no cover """Length.""" return len(self.languages) - def __getitem__(self, index): # pragma: no cover + def __getitem__(self, index: int) -> str: # pragma: no cover """Get item.""" return self.languages[index] @@ -300,36 +354,45 @@ class SelectorList(Immutable): __slots__ = ("selectors", "is_not", "is_html", "_hash") - def __init__(self, selectors=tuple(), is_not=False, is_html=False): + selectors: tuple[Selector | SelectorNull, ...] + is_not: bool + is_html: bool + + def __init__( + self, + selectors: Optional[Iterable[Selector | SelectorNull]] = None, + is_not: bool = False, + is_html: bool = False + ) -> None: """Initialize.""" - super(SelectorList, self).__init__( - selectors=tuple(selectors), + super().__init__( + selectors=tuple(selectors) if selectors is not None else tuple(), is_not=is_not, is_html=is_html ) - def __iter__(self): + def __iter__(self) -> Iterator[Selector | SelectorNull]: """Iterator.""" return iter(self.selectors) - def __len__(self): + def __len__(self) -> int: """Length.""" return len(self.selectors) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Selector | SelectorNull: """Get item.""" return self.selectors[index] -def _pickle(p): +def _pickle(p: Any) -> Any: return p.__base__(), tuple([getattr(p, s) for s in p.__slots__[:-1]]) -def pickle_register(obj): +def pickle_register(obj: Any) -> None: """Allow object to be pickled.""" copyreg.pickle(obj, _pickle) diff --git a/lib/soupsieve/pretty.py b/lib/soupsieve/pretty.py new file mode 100644 index 00000000..4c883347 --- /dev/null +++ b/lib/soupsieve/pretty.py @@ -0,0 +1,138 @@ +""" +Format a pretty string of a `SoupSieve` object for easy debugging. + +This won't necessarily support all types and such, and definitely +not support custom outputs. + +It is mainly geared towards our types as the `SelectorList` +object is a beast to look at without some indentation and newlines. +The format and various output types is fairly known (though it +hasn't been tested extensively to make sure we aren't missing corners). + +Example: + +``` +>>> import soupsieve as sv +>>> sv.compile('this > that.class[name=value]').selectors.pretty() +SelectorList( + selectors=( + Selector( + tag=SelectorTag( + name='that', + prefix=None), + ids=(), + classes=( + 'class', + ), + attributes=( + SelectorAttribute( + attribute='name', + prefix='', + pattern=re.compile( + '^value$'), + xml_type_pattern=None), + ), + nth=(), + selectors=(), + relation=SelectorList( + selectors=( + Selector( + tag=SelectorTag( + name='this', + prefix=None), + ids=(), + classes=(), + attributes=(), + nth=(), + selectors=(), + relation=SelectorList( + selectors=(), + is_not=False, + is_html=False), + rel_type='>', + contains=(), + lang=(), + flags=0), + ), + is_not=False, + is_html=False), + rel_type=None, + contains=(), + lang=(), + flags=0), + ), + is_not=False, + is_html=False) +``` +""" +from __future__ import annotations +import re +from typing import Any + +RE_CLASS = re.compile(r'(?i)[a-z_][_a-z\d\.]+\(') +RE_PARAM = re.compile(r'(?i)[_a-z][_a-z\d]+=') +RE_EMPTY = re.compile(r'\(\)|\[\]|\{\}') +RE_LSTRT = re.compile(r'\[') +RE_DSTRT = re.compile(r'\{') +RE_TSTRT = re.compile(r'\(') +RE_LEND = re.compile(r'\]') +RE_DEND = re.compile(r'\}') +RE_TEND = re.compile(r'\)') +RE_INT = re.compile(r'\d+') +RE_KWORD = re.compile(r'(?i)[_a-z][_a-z\d]+') +RE_DQSTR = re.compile(r'"(?:\\.|[^"\\])*"') +RE_SQSTR = re.compile(r"'(?:\\.|[^'\\])*'") +RE_SEP = re.compile(r'\s*(,)\s*') +RE_DSEP = re.compile(r'\s*(:)\s*') + +TOKENS = { + 'class': RE_CLASS, + 'param': RE_PARAM, + 'empty': RE_EMPTY, + 'lstrt': RE_LSTRT, + 'dstrt': RE_DSTRT, + 'tstrt': RE_TSTRT, + 'lend': RE_LEND, + 'dend': RE_DEND, + 'tend': RE_TEND, + 'sqstr': RE_SQSTR, + 'sep': RE_SEP, + 'dsep': RE_DSEP, + 'int': RE_INT, + 'kword': RE_KWORD, + 'dqstr': RE_DQSTR +} + + +def pretty(obj: Any) -> str: # pragma: no cover + """Make the object output string pretty.""" + + sel = str(obj) + index = 0 + end = len(sel) - 1 + indent = 0 + output = [] + + while index <= end: + m = None + for k, v in TOKENS.items(): + m = v.match(sel, index) + + if m: + name = k + index = m.end(0) + if name in ('class', 'lstrt', 'dstrt', 'tstrt'): + indent += 4 + output.append('{}\n{}'.format(m.group(0), " " * indent)) + elif name in ('param', 'int', 'kword', 'sqstr', 'dqstr', 'empty'): + output.append(m.group(0)) + elif name in ('lend', 'dend', 'tend'): + indent -= 4 + output.append(m.group(0)) + elif name in ('sep',): + output.append('{}\n{}'.format(m.group(1), " " * indent)) + elif name in ('dsep',): + output.append('{} '.format(m.group(1))) + break + + return ''.join(output) diff --git a/lib/soupsieve/util.py b/lib/soupsieve/util.py index c8244b11..519c763a 100644 --- a/lib/soupsieve/util.py +++ b/lib/soupsieve/util.py @@ -1,7 +1,9 @@ """Utility.""" +from __future__ import annotations from functools import wraps, lru_cache import warnings import re +from typing import Callable, Any, Optional DEBUG = 0x00001 @@ -12,7 +14,7 @@ UC_Z = ord('Z') @lru_cache(maxsize=512) -def lower(string): +def lower(string: str) -> str: """Lower.""" new_string = [] @@ -25,7 +27,7 @@ def lower(string): class SelectorSyntaxError(Exception): """Syntax error in a CSS selector.""" - def __init__(self, msg, pattern=None, index=None): + def __init__(self, msg: str, pattern: Optional[str] = None, index: Optional[int] = None) -> None: """Initialize.""" self.line = None @@ -37,30 +39,34 @@ class SelectorSyntaxError(Exception): self.context, self.line, self.col = get_pattern_context(pattern, index) msg = '{}\n line {}:\n{}'.format(msg, self.line, self.context) - super(SelectorSyntaxError, self).__init__(msg) + super().__init__(msg) -def deprecated(message, stacklevel=2): # pragma: no cover +def deprecated(message: str, stacklevel: int = 2) -> Callable[..., Any]: # pragma: no cover """ Raise a `DeprecationWarning` when wrapped function/method is called. - Borrowed from https://stackoverflow.com/a/48632082/866026 + Usage: + + @deprecated("This method will be removed in version X; use Y instead.") + def some_method()" + pass """ - def _decorator(func): + def _wrapper(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def _func(*args, **kwargs): + def _deprecated_func(*args: Any, **kwargs: Any) -> Any: warnings.warn( - "'{}' is deprecated. {}".format(func.__name__, message), + f"'{func.__name__}' is deprecated. {message}", category=DeprecationWarning, stacklevel=stacklevel ) return func(*args, **kwargs) - return _func - return _decorator + return _deprecated_func + return _wrapper -def warn_deprecated(message, stacklevel=2): # pragma: no cover +def warn_deprecated(message: str, stacklevel: int = 2) -> None: # pragma: no cover """Warn deprecated.""" warnings.warn( @@ -70,14 +76,15 @@ def warn_deprecated(message, stacklevel=2): # pragma: no cover ) -def get_pattern_context(pattern, index): +def get_pattern_context(pattern: str, index: int) -> tuple[str, int, int]: """Get the pattern context.""" last = 0 current_line = 1 col = 1 - text = [] + text = [] # type: list[str] line = 1 + offset = None # type: Optional[int] # Split pattern by newline and handle the text before the newline for m in RE_PATTERN_LINE_SPLIT.finditer(pattern):