#!/usr/bin/env python3
import json
import logging
import argparse
import xml.etree.ElementTree as ET

from rich import print as rprint
from rich.logging import RichHandler

# Set up the logger
handler = RichHandler(rich_tracebacks=True, markup=True)
logger = logging.getLogger("WebXmlParser")
logger.addHandler(handler)


class WebXmlParser:
    """
    Parser for a web.xml deployment descriptor
    """

    DEFAULT_UNKNOWN_TAG = "(Unknown)"

    def __init__(self, xml):
        self.xml = xml
        self.ns = None
        self.display_name = None
        self.json_parsed = None

    @staticmethod
    def url_patterns_match(filter_pattern, servlet_pattern):
        """
        Check whether a filter url-pattern applies to a servlet url-pattern
        """
        # Normalize patterns
        f = filter_pattern.strip()
        s = servlet_pattern.strip()
        # Exact match or match-all pattern
        if f == s or f in ("/", "/*"):
            return True
        # Filter path prefix match (e.g. /foo/*)
        if f.endswith("/*"):
            prefix = f[:-2]
            # The servlet pattern must be the prefix itself or start with prefix + "/"
            if s == prefix or s.startswith(prefix + "/"):
                return True
        # Filter extension match (e.g. *.jsp)
        if f.startswith("*."):
            ext = f[1:]
            return s.endswith(ext)
        return False

    def tag(self, name):
        """
        Return the tag associated with a name, qualified by the namespace if any
        """
        return f"{{{self.ns}}}{name}" if self.ns else name

    def get_servlets(self, root):
        """
        Recover the servlets

        :rtype: Dict[str, Dict[str, str]]
        :return: A dictionary of all servlets indexed by their name
        """
        servlets = {}
        for servlet in root.findall(self.tag("servlet")):
            name = servlet.findtext(self.tag("servlet-name"), default=self.__class__.DEFAULT_UNKNOWN_TAG).strip()
            clazz = servlet.findtext(self.tag("servlet-class"), default="").strip()
            jsp_file = servlet.findtext(self.tag("jsp-file"), default="").strip()
            servlets[name] = {"class": clazz, "jsp_file": jsp_file, "url_mapping": []}
        return servlets

    def get_filters(self, root):
        """
        Recover the filters in use in the application
        """
        filters = {}
        for servlet_filter in root.findall(self.tag("filter")):
            name = servlet_filter.findtext(self.tag("filter-name"), default=self.__class__.DEFAULT_UNKNOWN_TAG).strip()
            clazz = servlet_filter.findtext(self.tag("filter-class"), default="").strip()
            filters[name] = {"class": clazz, "url_mapping": []}
        return filters

    def get_mapping(self, root, entities, entity_name):
        """
        Recover the url-pattern mapping of each entity (servlet or filter)
        """
        for mapping in root.findall(self.tag(f"{entity_name}-mapping")):
            name = mapping.findtext(self.tag(f"{entity_name}-name"), default=self.__class__.DEFAULT_UNKNOWN_TAG).strip()
            if name not in entities:
                # A mapping can reference an entity that was never declared
                logger.warning(f"Mapping found for unknown {entity_name} '{name}'")
                continue
            for pattern in mapping.findall(self.tag("url-pattern")):
                if pattern.text:
                    entities[name]["url_mapping"].append(pattern.text.strip())

    def mix_filter_to_servlet(self, servlets, filters):
        """
        Attach each filter to the servlets it applies to by checking the url-patterns
        """
        def update_filter(result, filter_patterns, servlet_patterns):
            for f_pattern in filter_patterns:
                for s_pattern in servlet_patterns:
                    if WebXmlParser.url_patterns_match(f_pattern, s_pattern):
                        if not result[servlet_name].get("filters"):
                            result[servlet_name]["filters"] = {}
                        result[servlet_name]["filters"][filter_name] = filter_attrs
                        return

        result = {}
        for servlet_name, servlet_attrs in servlets.items():
            # Initialize the result
            result[servlet_name] = servlet_attrs
            servlet_patterns = servlet_attrs["url_mapping"]
            for filter_name, filter_attrs in filters.items():
                filter_patterns = filter_attrs["url_mapping"]
                update_filter(result, filter_patterns, servlet_patterns)
        return result

    def parse(self):
        """
        Parse self.xml
        """
        try:
            logger.info(f"Parsing '{self.xml}'")
            tree = ET.parse(self.xml)
            root = tree.getroot()
        except Exception as e:
            logger.critical(f"Error while parsing {self.xml}: {e}")
            return

        # Recover the XML namespace from the root tag, if present
        if root.tag.startswith("{"):
            self.ns = root.tag[root.tag.find("{") + 1:root.tag.find("}")]

        display_name = root.find(self.tag("display-name"))
        if display_name is not None:
            self.display_name = display_name.text
            logger.info(f"Application name '{self.display_name}'")
        else:
            self.display_name = self.__class__.DEFAULT_UNKNOWN_TAG

        servlets = self.get_servlets(root)
        self.get_mapping(root, servlets, "servlet")
        logger.info(f"Recovered {len(servlets)} servlets")

        filters = self.get_filters(root)
        self.get_mapping(root, filters, "filter")
        logger.info(f"Recovered {len(filters)} filters")

        self.json_parsed = self.mix_filter_to_servlet(servlets, filters)

    def display(self, json_format=False):
        """
        Display the result either as json or as a pretty terminal output
        """
        if not self.json_parsed:
            return
        if json_format:
            print(json.dumps(self.json_parsed, indent=4))
        else:
            indent = " " * 2
            for servlet_name, attrs in self.json_parsed.items():
                rprint(f"[b]Servlet:[/b] [green]{servlet_name}[/green]")
                if attrs.get("jsp_file"):
                    rprint(f"{indent}[b]JSP file:[/b] [blue]{attrs.get('jsp_file')}[/blue]")
                else:
                    rprint(f"{indent}[b]Class:[/b] [blue]{attrs.get('class')}[/blue]")
                rprint(f"{indent}[b]Urls:[/b]")
                for url_pattern in attrs["url_mapping"]:
                    rprint(f"{indent} - [red]{url_pattern}[/red]")
                if attrs.get("filters"):
                    rprint(f"{indent}[b]Filters:[/b]")
                    for filter_name, filter_attrs in attrs["filters"].items():
                        rprint(f"{indent} - [light_sea_green]{filter_name}[/light_sea_green] ({filter_attrs.get('class')})")
                else:
                    rprint(f"{indent}[b]No Filters[/b]")
                rprint()


def options():
    """
    Parse cli options
    """
    parser = argparse.ArgumentParser(description="Parse a tomcat web.xml file")
    parser.add_argument("-x", "--xml", nargs="+", required=False,
                        help="Path to web.xml file (default: %(default)s)", default=["web.xml"])
    parser.add_argument("-v", "--verbose", help="Increase verbosity", action="store_true")
    parser.add_argument("-j", "--json", help="Display output in json", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = options()
    logger.setLevel(logging.DEBUG if args.verbose else logging.WARNING)
    for web_xml in args.xml:
        xml_parser = WebXmlParser(web_xml)
        xml_parser.parse()
        xml_parser.display(json_format=args.json)