Skip to content

Commit

Permalink
review feedback(aggregate return lowering) (#1397)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhasel authored Jan 28, 2025
1 parent e731987 commit 6c4a761
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 103 deletions.
12 changes: 6 additions & 6 deletions compiler/plc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ impl Pou {
}

pub fn calc_return_name(pou_name: &str) -> &str {
pou_name.split('.').last().unwrap_or_default()
pou_name.rsplit_once('.').map(|(_, return_name)| return_name).unwrap_or(pou_name)
}

pub fn is_aggregate(&self) -> bool {
Expand Down Expand Up @@ -317,7 +317,7 @@ impl PouType {
}
}

pub fn is_function_or_init(&self) -> bool {
pub fn is_function_method_or_init(&self) -> bool {
matches!(self, PouType::Function | PouType::Init | PouType::ProjectInit | PouType::Method { .. })
}
}
Expand Down Expand Up @@ -538,7 +538,7 @@ impl From<&DataTypeDeclaration> for SourceLocation {
impl DataTypeDeclaration {
pub fn get_name(&self) -> Option<&str> {
match self {
Self::Aggregate { referenced_type, .. }
DataTypeDeclaration::Aggregate { referenced_type, .. }
| DataTypeDeclaration::DataTypeReference { referenced_type, .. } => {
Some(referenced_type.as_str())
}
Expand All @@ -550,7 +550,7 @@ impl DataTypeDeclaration {
match self {
DataTypeDeclaration::DataTypeReference { location, .. } => location.clone(),
DataTypeDeclaration::DataTypeDefinition { location, .. } => location.clone(),
Self::Aggregate { location, .. } => location.clone(),
DataTypeDeclaration::Aggregate { location, .. } => location.clone(),
}
}

Expand All @@ -570,12 +570,12 @@ impl DataTypeDeclaration {

None
}
Self::Aggregate { .. } => None,
DataTypeDeclaration::Aggregate { .. } => None,
}
}

pub fn is_aggregate(&self) -> bool {
matches!(self, Self::Aggregate { .. })
matches!(self, DataTypeDeclaration::Aggregate { .. })
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/codegen/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ impl<'ink> DebugBuilder<'ink> {
}

let implementation = pou.find_implementation(index).expect("A POU will have an impl at this stage");
if !implementation.get_implementation_type().is_function_or_init() {
if !implementation.get_implementation_type().is_function_method_or_init() {
self.register_struct_parameter(pou, func);
} else {
let declared_params = index.get_declared_parameters(implementation.get_call_name());
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/generators/pou_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> {
}
_ => {
dti.map(|it| {
if !implementation.get_implementation_type().is_function_or_init() {
if !implementation.get_implementation_type().is_function_method_or_init() {
return *p;
}
// for aggregate function parameters we will generate a pointer instead of the value type.
Expand Down Expand Up @@ -354,7 +354,7 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> {
&self,
implementation: &ImplementationIndexEntry,
) -> Result<Vec<BasicMetadataTypeEnum<'ink>>, Diagnostic> {
if !implementation.implementation_type.is_function_or_init() {
if !implementation.implementation_type.is_function_method_or_init() {
let mut parameters = vec![];
let instance_struct_type: StructType = self
.llvm_index
Expand Down Expand Up @@ -456,7 +456,7 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> {
}

// generate local variables
if implementation.pou_type.is_function_or_init() {
if implementation.pou_type.is_function_method_or_init() {
self.generate_local_function_arguments_accessors(
&mut local_index,
&implementation.type_name,
Expand Down
10 changes: 5 additions & 5 deletions src/codegen/generators/statement_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct StatementCodeGenerator<'a, 'b> {
llvm: &'b Llvm<'a>,
index: &'b Index,
annotations: &'b AstAnnotations,
llvm_inex: &'b LlvmTypedIndex<'a>,
llvm_index: &'b LlvmTypedIndex<'a>,
function_context: &'b FunctionContext<'a, 'b>,

pub load_prefix: String,
Expand All @@ -65,15 +65,15 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {
llvm: &'b Llvm<'a>,
index: &'b Index,
annotations: &'b AstAnnotations,
llvm_inex: &'b LlvmTypedIndex<'a>,
llvm_index: &'b LlvmTypedIndex<'a>,
linking_context: &'b FunctionContext<'a, 'b>,
debug: &'b DebugBuilderEnum<'a>,
) -> StatementCodeGenerator<'a, 'b> {
StatementCodeGenerator {
llvm,
index,
annotations,
llvm_inex,
llvm_index,
function_context: linking_context,
load_prefix: "load_".to_string(),
load_suffix: "".to_string(),
Expand All @@ -100,7 +100,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {

/// generates a list of statements
pub fn generate_body(&self, statements: &[AstNode]) -> Result<(), Diagnostic> {
let mut child_index = LlvmTypedIndex::create_child(self.llvm_inex);
let mut child_index = LlvmTypedIndex::create_child(self.llvm_index);
for s in statements {
child_index = self.generate_statement(child_index, s)?;
}
Expand Down Expand Up @@ -783,7 +783,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {
let var_name = format!("{call_name}_ret"); // TODO: Naming convention (see plc_util/src/convention.rs)
let ret_name = ret_v.get_qualified_name();
let value_ptr =
self.llvm_inex.find_loaded_associated_variable_value(ret_name).ok_or_else(|| {
self.llvm_index.find_loaded_associated_variable_value(ret_name).ok_or_else(|| {
Diagnostic::codegen_error(
format!("Cannot generate return variable for {call_name:}"),
SourceLocation::undefined(),
Expand Down
2 changes: 0 additions & 2 deletions src/codegen/tests/code_gen_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,6 @@ fn fb_method_with_var_input_defaults() {
insta::assert_snapshot!(prg);
}

//A test for a method with an initialized input variable
#[test]
fn method_codegen_with_initialized_input() {
let prg = codegen(
Expand All @@ -1027,7 +1026,6 @@ fn method_codegen_with_initialized_input() {
insta::assert_snapshot!(prg);
}

//A test for a method with multiple input variables
#[test]
fn method_codegen_with_multiple_input() {
let prg = codegen(
Expand Down
4 changes: 1 addition & 3 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,7 @@ impl From<&PouType> for ImplementationType {
}

impl ImplementationType {
// TODO: this now also takes methods into accounts, find a better name
pub fn is_function_or_init(&self) -> bool {
pub fn is_function_method_or_init(&self) -> bool {
matches!(
self,
ImplementationType::Function
Expand Down Expand Up @@ -762,7 +761,6 @@ impl PouIndexEntry {
match self {
PouIndexEntry::Program { instance_struct_name, .. }
| PouIndexEntry::FunctionBlock { instance_struct_name, .. }
// | PouIndexEntry::Method { instance_struct_name, .. }
| PouIndexEntry::Action { instance_struct_name, .. }
| PouIndexEntry::Class { instance_struct_name, .. } => Some(instance_struct_name.as_str()),
_ => None, //functions have no struct type
Expand Down
126 changes: 76 additions & 50 deletions src/lowering/calls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,49 @@
//! Changes the calls to aggregate return types
//! to make them VAR_IN_OUT calls, allowing them
//! to be called from C_APIs and simplifying code generation
//!
//! As a first step, the POU signature is changed. E.g. a function
//! returning a `STRING` will now return `__VOID` with the return variable
//! being moved into a `VAR_IN_OUT` block:
//! ```iec61131
//! // user code
//! FUNCTION foo : STRING
//! VAR_INPUT
//! a: DINT;
//! END_VAR
//! END_FUNCTION
//! ```
//! ```iec61131
//! // lowered equivalent
//! FUNCTION foo
//! VAR_IN_OUT
//! foo: STRING;
//! END_VAR
//! VAR_INPUT
//! a: DINT;
//! END_VAR
//! END_FUNCTION
//! ```
//!
//! Next, every call-statement to that POU has it's arguments updated, with a temporary
//! variable being allocated to hold the value.
//! Locally allocated variables follow a naming-scheme of `__<function_name><number>`,
//! <number> being a value from an atomically incremented counter to avoid naming conflicts
//! (the same approach is used for allocated variables in LLVM-IR).
//! ```iec61131
//! // user code. Let `s` be a variable of type `STRING`
//! // ...
//! s := foo(42);
//! // ...
//! ```
//! ```iec61131
//! // lowered equivalent
//! // ...
//! alloca __foo1 : STRING;
//! foo(__foo1, 42);
//! s := __foo1;
//! // ...
//! ```
use std::{borrow::BorrowMut, sync::atomic::AtomicI32};

Expand Down Expand Up @@ -145,37 +188,43 @@ impl AstVisitorMut for AggregateTypeLowerer {
return;
}
let index = self.index.as_ref().expect("Can't get here without an index");
//Check if pou has a return type
if let Some(return_var) = pou.return_type.take() {
let name = return_var.get_name().expect("We should have names at this point");
let location = return_var.get_location();

// Check if POU has a return type
let Some(return_type_name) = pou
.return_type
.as_ref()
.map(|it| it.get_name().expect("We should have names at this point").to_string())
else {
return;
};

// If the return type is aggregate, remove it from the signature and add a matching variable
// in a VAR_IN_OUT block
if index.get_effective_type_or_void_by_name(&return_type_name).is_aggregate_type() {
let original_return = pou.return_type.take().unwrap();
let location = original_return.get_location();
//Create a new return type for the pou
pou.return_type.replace(plc_ast::ast::DataTypeDeclaration::Aggregate {
referenced_type: name.to_string(),
referenced_type: return_type_name,
location,
});
let data_type = index.get_effective_type_or_void_by_name(name);
if data_type.is_aggregate_type() {
//Insert a new in out var to the pou variable block declarations
let block = VariableBlock {
access: AccessModifier::Public,
constant: false,
retain: false,
variables: vec![Variable {
name: pou.get_return_name().to_string(),
data_type_declaration: return_var,
initializer: None,
address: None,
location: pou.name_location.clone(),
}],
variable_block_type: VariableBlockType::InOut,
linkage: LinkageType::Internal,
location: SourceLocation::internal(),
};
pou.variable_blocks.insert(0, block)
} else {
pou.return_type.replace(return_var);
}
//Insert a new in out var to the pou variable block declarations
let block = VariableBlock {
access: AccessModifier::Public,
constant: false,
retain: false,
variables: vec![Variable {
name: pou.get_return_name().to_string(),
data_type_declaration: original_return,
initializer: None,
address: None,
location: pou.name_location.clone(),
}],
variable_block_type: VariableBlockType::InOut,
linkage: LinkageType::Internal,
location: SourceLocation::internal(),
};
pou.variable_blocks.insert(0, block)
}
}

Expand Down Expand Up @@ -449,29 +498,6 @@ mod tests {
assert_debug_snapshot!(lowerer.index.unwrap().find_pou_type("fb.complexMethod").unwrap());
}

// Are we in a call?
// foo(x:= baz()); callStatement -> Reference baz_1
// foo(x:= baz()); callStatement -> Reference baz_2
// foo(x:= baz()); callStatement -> Reference baz_3
// foo(x:= baz()); callStatement -> Reference
// foo(x:= baz()); callStatement -> Reference
// foo(x:= baz()); callStatement -> Reference
// foo(x:= baz()); callStatement -> Reference
// alloca temp;
// baz(temp);
// foo(x := temp)
//
// call -> alloc, call, ref
//
// Insert alloca _before_ the call statement
// x := foo();
// alloca temp
// foo(temp);
// x := temp;
//Check right, if a function call with aggregate, add allocation
//fix call
//assign to allocation
//
#[test]
fn simple_call_statement() {
let id_provider = IdProvider::default();
Expand Down
27 changes: 2 additions & 25 deletions src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! records all resulting types associated with the statement's id.
use rustc_hash::{FxHashMap, FxHashSet};
use std::{any::Any, fmt::Debug, hash::Hash};
use std::{fmt::Debug, hash::Hash};

use plc_ast::{
ast::{
Expand Down Expand Up @@ -554,7 +554,7 @@ impl Dependency {
}
}

pub trait AnnotationMap: ToAny {
pub trait AnnotationMap {
fn get(&self, s: &AstNode) -> Option<&StatementAnnotation>;

fn get_hint(&self, s: &AstNode) -> Option<&StatementAnnotation>;
Expand Down Expand Up @@ -639,29 +639,6 @@ pub struct AstAnnotations {
bool_annotation: StatementAnnotation,
}

pub trait ToAny: 'static {
fn as_any(&mut self) -> &mut dyn Any;
}

impl<T: AnnotationMap> ToAny for T {
fn as_any(&mut self) -> &mut dyn Any {
self
}
}

impl AstAnnotations {
pub fn from_dyn(mut annotation_map: Box<dyn AnnotationMap>, bool_id: AstId) -> Self {
let it: &mut dyn Any = annotation_map.as_any();
if let Some(map) = it.downcast_mut::<AstAnnotations>().map(std::mem::take) {
return map;
}
let annotation_map =
it.downcast_mut::<AnnotationMapImpl>().map(std::mem::take).expect("AnnotationMapImpl");

Self::new(annotation_map, bool_id)
}
}

impl AnnotationMap for AstAnnotations {
fn get(&self, s: &AstNode) -> Option<&StatementAnnotation> {
if s.get_id() == self.bool_id {
Expand Down
Loading

0 comments on commit 6c4a761

Please sign in to comment.