# Source code for validate_actions.rules.expressions_contexts

"""Validates expressions against workflow contexts."""
import re
from collections.abc import Mapping, Sequence
from dataclasses import fields, is_dataclass
from difflib import SequenceMatcher
from typing import Generator, Optional

from validate_actions.domain_model.contexts import Contexts
from validate_actions.domain_model.primitives import Expression, String
from validate_actions.globals.problems import Problem, ProblemLevel
from validate_actions.rules.rule import Rule


class ExpressionsContexts(Rule):
    """Check that ``${{ ... }}`` expressions reference known workflow contexts.

    Every :class:`Expression` found in the workflow AST is resolved, part by
    part, against the :class:`Contexts` object in scope at that point. When a
    part is unknown, a :class:`Problem` is produced and the closest matching
    name (if any is close enough) is handed to
    ``self.fixer.edit_yaml_at_position``.
    """

    NAME = "expressions-contexts"

    # Substrings that mark an expression as using operators. Such expressions
    # are not validated yet (TODO), so their presence short-circuits the check.
    _OPERATORS = ("!", "<=", "<", ">=", ">", "==", "!=", "&&", "||")
    # Matches a function call such as ``contains(a, b)``; compiled once.
    _FUNCTION_REGEX = re.compile(r"\b([A-Za-z_][A-Za-z0-9_]*)\s*\(\s*([^)]*?)\s*\)")
    # Top-level contexts whose contents are not checked.
    # TODO unshelf needs and steps
    _UNCHECKED_CONTEXTS = ("vars", "secrets", "inputs", "steps", "env")
    # A suggestion is only used when its SequenceMatcher ratio exceeds this.
    _MATCH_THRESHOLD = 0.8

    # ====================
    # MAIN VALIDATION METHODS
    # ====================

    def check(
        self,
    ) -> Generator[Problem, None, None]:
        """Yield a problem for every expression that does not resolve.

        Traversal starts with the workflow-level contexts; AST nodes that carry
        their own ``contexts`` field switch the scope for their subtree.
        """
        for ref, ctx in self._traverse(self.workflow, self.workflow.contexts):
            problem = self.does_expr_exist(ref, ctx)
            if problem:
                yield problem

    def does_expr_exist(self, expr: Expression, contexts: Contexts) -> Optional[Problem]:
        """Resolve ``expr`` against ``contexts``.

        Args:
            expr: The expression whose ``parts`` are looked up one by one.
            contexts: The context tree in scope for ``expr``.

        Returns:
            ``None`` when the expression resolves or is deliberately skipped
            (operators, function calls, single-part literals, the contexts in
            ``_UNCHECKED_CONTEXTS``, and ``github.event.*``). Otherwise a
            :class:`Problem`; for an unknown property this is whatever
            :meth:`_fix_unknown_property` returns.
        """
        problem = Problem(
            pos=expr.pos,
            desc=f"Expression '{expr.string}' does not match any context",
            level=ProblemLevel.ERR,
            rule=self.NAME,
        )

        # TODO: expressions with operators or function calls are not validated yet.
        if any(op in expr.string for op in self._OPERATORS):
            return None
        if self._FUNCTION_REGEX.search(expr.string):
            return None

        parts = expr.parts or []
        if not parts:
            return problem
        # A single part is treated as a literal.
        if len(parts) == 1:
            return None

        # Compare the raw text of each part, as the lookup loop below does;
        # comparing the ``String`` objects themselves against ``str`` values
        # relies on ``String.__eq__`` semantics.
        head = parts[0].string
        if head in self._UNCHECKED_CONTEXTS:
            return None
        if head == "github" and parts[1].string == "event":
            return None

        # Walk the context tree one part at a time.
        cur = contexts
        for part in parts:
            name = part.string
            if hasattr(cur, name):
                cur = getattr(cur, name)
            elif hasattr(cur, "children_") and name in cur.children_:
                cur = cur.children_[name]
            elif hasattr(cur, "functions_") and name in cur.functions_:
                cur = cur.functions_[name]
            elif isinstance(cur, list) and name in cur:
                # Continue with the list element itself (it may not be a plain str).
                cur = cur[cur.index(name)]
            else:
                problem.desc = (
                    f"Expression '{expr.string}' does not match any context. "
                    f"Unknown property '{part.string}'"
                )
                return self._fix_unknown_property(expr, part, cur, problem)
        return None

    # ====================
    # UTILITY METHODS
    # ====================

    def _traverse(
        self, obj, cur_context: Contexts
    ) -> Generator[tuple[Expression, Contexts], None, None]:
        """
        Recursively traverse AST, yielding (Expression, Contexts) pairs.
        Update context when encountering a node with its own 'contexts' field.
        """
        # direct Expression: emit with current context
        if isinstance(obj, Expression):
            yield obj, cur_context
            return
        # skip walking inside the Contexts definitions themselves
        if isinstance(obj, Contexts):
            return
        # dataclass nodes: check for own contexts, then traverse fields
        if is_dataclass(obj):
            # switch to local context if available
            own_context = getattr(obj, "contexts", None)
            new_context = own_context if isinstance(own_context, Contexts) else cur_context
            for f in fields(obj):
                if f.name == "contexts":
                    # do not traverse into context definitions
                    continue
                try:
                    val = getattr(obj, f.name)
                except AttributeError:
                    continue
                yield from self._traverse(val, new_context)
            return
        # mappings and sequences: propagate current context
        if isinstance(obj, Mapping):
            for v in obj.values():
                yield from self._traverse(v, cur_context)
            return
        if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
            for item in obj:
                yield from self._traverse(item, cur_context)
            return

    # ====================
    # FIXING METHODS
    # ====================

    @staticmethod
    def _best_match(needle: str, candidates) -> tuple:
        """Return ``(candidate, ratio)`` with the highest similarity to ``needle``.

        Similarity is ``difflib.SequenceMatcher.ratio``. Returns ``(None, 0)``
        when there are no candidates.
        """
        scores = {key: SequenceMatcher(None, needle, key).ratio() for key in candidates}
        return max(scores.items(), key=lambda kv: kv[1], default=(None, 0))

    def _fix_unknown_property(
        self, expr: Expression, part: String, cur, problem: Problem
    ) -> Problem:
        """Fix unknown property by finding and suggesting the best match.

        Candidate names are the dataclass field names of ``cur`` (if it is a
        dataclass) plus either the items of ``cur`` (if it is a list), its
        ``children_`` keys, or its ``functions_`` keys.

        Args:
            expr: The expression being validated (used in the fix message).
            part: The unknown part; its text and position are edited.
            cur: The context node in which ``part`` was not found.
            problem: The problem to return unchanged when no candidate is close
                enough, or to pass to the fixer otherwise.

        Returns:
            ``problem`` if no candidate scores above ``_MATCH_THRESHOLD``;
            otherwise the result of ``self.fixer.edit_yaml_at_position``.
        """
        field_names: list[str] = []
        others: list = []
        if isinstance(cur, list):
            others = cur
        else:
            # ``dataclasses.fields`` raises TypeError for non-dataclass values
            # (e.g. a leaf value reached during the walk), so guard the call.
            if is_dataclass(cur):
                field_names = [f.name for f in fields(cur)]
            if hasattr(cur, "children_"):
                others = list(cur.children_.keys())
            elif hasattr(cur, "functions_"):
                others = list(cur.functions_.keys())

        fields_best_key, fields_best_score = self._best_match(part.string, field_names)
        others_best_key, others_best_score = self._best_match(part.string, others)

        threshold = self._MATCH_THRESHOLD
        if fields_best_score > threshold and others_best_score > threshold:
            # Both sources offer a candidate: prefer the longer key
            # (first one wins on equal length).
            max_key = max(
                (k for k in (fields_best_key, others_best_key) if k is not None),
                key=len,
                default="",
            )
        elif fields_best_score > threshold:
            max_key = fields_best_key or ""
        elif others_best_score > threshold:
            max_key = others_best_key or ""
        else:
            return problem

        updated_problem_desc = (
            f"Fixed '${{{{ {expr.string} }}}}': changed '{part.string}' to '{max_key}'"
        )
        return self.fixer.edit_yaml_at_position(
            idx=part.pos.idx,
            old_text=part.string,
            new_text=max_key,
            problem=problem,
            new_problem_desc=updated_problem_desc,
        )