Skip to content

Commit

Permalink
enable visit for CompilationUnit and its members
Browse files Browse the repository at this point in the history
the visitor can now be called on the CompilationUnit
and it will visit all GlobalVariables, all Pous, all
UserTypes and Implementations
  • Loading branch information
riederm committed Jun 2, 2024
1 parent 4776e10 commit 8cb832c
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 19 deletions.
200 changes: 199 additions & 1 deletion compiler/plc_ast/src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,63 @@ pub trait Walker {
/// }
/// ```
pub trait AstVisitor: Sized {
/// Walks through an `AstNode` and applies the visitor's `walk` method to each node.
/// visits through an `AstNode` and applies the visitor's `walk` method to each node.
fn visit(&mut self, node: &AstNode) {
node.walk(self)
}

/// Visits a `CompilationUnit` node.
/// Make sure to call `walk` on the `CompilationUnit` node to visit its children.
fn visit_compilation_unit(&mut self, unit: &CompilationUnit) {
unit.walk(self)
}

/// Visits an `Implementation` node.
/// Make sure to call `walk` on the `Implementation` node to visit its children.
fn visit_implementation(&mut self, implementation: &Implementation) {
implementation.walk(self);
}

/// Visits a `DataTypeDeclaration` node.
/// Make sure to call `walk` on the `VariableBlock` node to visit its children.
fn visit_variable_block(&mut self, block: &VariableBlock) {
block.walk(self)
}

/// Visits a `Variable` node.
/// Make sure to call `walk` on the `Variable` node to visit its children.
fn visit_variable(&mut self, variable: &Variable) {
variable.walk(self);
}

fn visit_enum_element(&mut self, element: &AstNode) {
element.walk(self);
}

/// Visits a `DataTypeDeclaration` node.
/// Make sure to call `walk` on the `DataTypeDeclaration` node to visit its children.
fn visit_data_type_declaration(&mut self, data_type_declaration: &DataTypeDeclaration) {
data_type_declaration.walk(self);
}

/// Visits a `UserTypeDeclaration` node.
/// Make sure to call `walk` on the `UserTypeDeclaration` node to visit its children.
fn visit_user_type_declaration(&mut self, user_type: &UserTypeDeclaration) {
user_type.walk(self);
}

/// Visits a `UserTypeDeclaration` node.
/// Make sure to call `walk` on the `DataType` node to visit its children.
fn visit_data_type(&mut self, data_type: &DataType) {
data_type.walk(self);
}

/// Visits a `Pou` node.
/// Make sure to call `walk` on the `Pou` node to visit its children.
fn visit_pou(&mut self, pou: &Pou) {
pou.walk(self);
}

/// Visits an `EmptyStatement` node.
/// Make sure to call `walk` on the `EmptyStatement` node to visit its children.
fn visit_empty_statement(&mut self, _stmt: &EmptyStatement, _node: &AstNode) {}
Expand Down Expand Up @@ -96,6 +148,7 @@ pub trait AstVisitor: Sized {
}

/// Visits an `Identifier` node.
/// Make sure to call `walk` on the `Identifier` node to visit its children.
fn visit_identifier(&mut self, _stmt: &str, _node: &AstNode) {}

/// Visits a `DirectAccess` node.
Expand Down Expand Up @@ -427,3 +480,148 @@ impl Walker for AstNode {
}
}
}

impl Walker for CompilationUnit {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
for block in &self.global_vars {
visitor.visit_variable_block(block);
}

for user_type in &self.user_types {
visitor.visit_user_type_declaration(user_type);
}

for pou in &self.units {
visitor.visit_pou(pou);
}

for i in &self.implementations {
visitor.visit_implementation(i);
}
}
}

impl Walker for UserTypeDeclaration {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
visitor.visit_data_type(&self.data_type);
visit_all_nodes!(visitor, &self.initializer);
}
}

impl Walker for VariableBlock {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
for v in self.variables.iter() {
visitor.visit_variable(v);
}
}
}

impl Walker for Variable {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
visit_all_nodes!(visitor, &self.address);
visitor.visit_data_type_declaration(&self.data_type_declaration);
visit_all_nodes!(visitor, &self.initializer);
}
}

impl Walker for DataType {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
match self {
DataType::StructType { variables, .. } => {
for v in variables.iter() {
visitor.visit_variable(v);
}
}
DataType::EnumType { elements, .. } => {
for ele in flatten_expression_list(elements) {
visitor.visit_enum_element(ele);
}
}
DataType::SubRangeType { bounds, .. } => {
visit_all_nodes!(visitor, bounds);
}
DataType::ArrayType { bounds, referenced_type, .. } => {
visitor.visit(bounds);
visitor.visit_data_type_declaration(referenced_type);
}
DataType::PointerType { referenced_type, .. } => {
visitor.visit_data_type_declaration(referenced_type);
}
DataType::StringType { size, .. } => {
visit_all_nodes!(visitor, size);
}
DataType::VarArgs { referenced_type, .. } => {
if let Some(data_type_declaration) = referenced_type {
visitor.visit_data_type_declaration(data_type_declaration);
}
}
DataType::GenericType { .. } => {
//no further visits
}
}
}
}

impl Walker for DataTypeDeclaration {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
if let DataTypeDeclaration::DataTypeDefinition { data_type, .. } = self {
visitor.visit_data_type(data_type);
}
}
}

impl<T> Walker for Option<T>
where
T: Walker,
{
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
if let Some(node) = self {
node.walk(visitor);
}
}
}

impl Walker for Pou {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
for block in &self.variable_blocks {
visitor.visit_variable_block(block);
}

self.return_type.as_ref().inspect(|rt| visitor.visit_data_type_declaration(rt));
}
}

impl Walker for Implementation {
fn walk<V>(&self, visitor: &mut V)
where
V: AstVisitor,
{
for n in &self.statements {
visitor.visit(n);
}
}
}
76 changes: 58 additions & 18 deletions src/parser/tests/ast_visitor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,21 @@ fn get_character_range(start: char, end: char) -> Vec<String> {
}

fn collect_identifiers(src: &str) -> IdentifierCollector {
let mut visitor = IdentifierCollector::default();
visit(src, &mut visitor);
visitor.identifiers.sort();
visitor
}

fn visit(src: &str, visitor: &mut impl AstVisitor) {
let id_provider = IdProvider::default();
let (compilation_unit, _) = parser::parse(
lexer::lex_with_ids(src, id_provider.clone(), SourceLocationFactory::internal(src)),
LinkageType::Internal,
"test.st",
);

let mut visitor = IdentifierCollector::default();

for st in &compilation_unit.implementations[0].statements {
visitor.visit(st);
}
visitor.identifiers.sort();
visitor
visitor.visit_compilation_unit(&compilation_unit)
}

#[test]
Expand Down Expand Up @@ -216,24 +217,63 @@ fn test_visit_return_statement() {
assert_eq!(get_character_range('a', 'b'), visitor.identifiers);
}

struct AssignmentCounter {
count: usize,
}

impl AstVisitor for AssignmentCounter {
fn visit_assignment(&mut self, stmt: &plc_ast::ast::Assignment, _node: &plc_ast::ast::AstNode) {
self.count += 1;
stmt.walk(self)
#[test]
fn test_visit_data_type_declaration() {
struct FieldCollector {
fields: Vec<String>,
}

fn visit_output_assignment(&mut self, stmt: &plc_ast::ast::Assignment, _node: &plc_ast::ast::AstNode) {
self.count += 1;
stmt.walk(self)
impl AstVisitor for FieldCollector {
fn visit_variable(&mut self, variable: &plc_ast::ast::Variable) {
self.fields.push(variable.name.clone());
variable.walk(self);
}

fn visit_enum_element(&mut self, element: &plc_ast::ast::AstNode) {
if let Some(name) = element.get_flat_reference_name() {
self.fields.push(name.to_string());
}
element.walk(self);
}
}
let mut visitor = FieldCollector { fields: vec![] };

visit(
"TYPE myStruct: STRUCT
a, b, c: DINT;
s: STRING;
e: (enum1, enum2, enum3);
END_STRUCT;
",
&mut visitor,
);

visitor.fields.sort();
assert_eq!(vec!["a", "b", "c", "e", "enum1", "enum2", "enum3", "s"], visitor.fields);
}

#[test]
fn test_count_assignments() {
struct AssignmentCounter {
count: usize,
}

impl AstVisitor for AssignmentCounter {
fn visit_assignment(&mut self, stmt: &plc_ast::ast::Assignment, _node: &plc_ast::ast::AstNode) {
self.count += 1;
stmt.walk(self)
}

fn visit_output_assignment(
&mut self,
stmt: &plc_ast::ast::Assignment,
_node: &plc_ast::ast::AstNode,
) {
self.count += 1;
stmt.walk(self)
}
}

let id_provider = IdProvider::default();
let (compilation_unit, _) = parser::parse(
lexer::lex_with_ids(
Expand Down

0 comments on commit 8cb832c

Please sign in to comment.