Skip to content

Commit

Permalink
fix: for loops no longer execute once when condition is already met (#…
Browse files Browse the repository at this point in the history
…1248)

* fix: for loop condition

This PR fixes for loops executing once when the predicate already should not be met for decrementing loops.
I have also re-implemented the codegen logic for for-loops, resulting in fewer predecessors and hopefully more
readable IR.

Resolves #1207
  • Loading branch information
mhasel authored Jun 27, 2024
1 parent c7f3d82 commit ef09b87
Show file tree
Hide file tree
Showing 13 changed files with 542 additions and 660 deletions.
199 changes: 96 additions & 103 deletions src/codegen/generators/statement_generator.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
// Copyright (c) 2020 Ghaith Hachem and Mathias Rieder
use super::{
expression_generator::{to_i1, ExpressionCodeGenerator},
expression_generator::{to_i1, ExpressionCodeGenerator, ExpressionValue},
llvm::Llvm,
};
use crate::{
codegen::debug::Debug,
codegen::{debug::DebugBuilderEnum, LlvmTypedIndex},
codegen::{
debug::{Debug, DebugBuilderEnum},
llvm_typesystem::cast_if_needed,
LlvmTypedIndex,
},
index::{ImplementationIndexEntry, Index},
resolver::{AnnotationMap, AstAnnotations, StatementAnnotation},
typesystem::DataTypeInformation,
typesystem::{get_bigger_type, DataTypeInformation, DINT_TYPE},
};
use inkwell::{
basic_block::BasicBlock,
builder::Builder,
context::Context,
values::{BasicValueEnum, FunctionValue, PointerValue},
values::{FunctionValue, PointerValue},
};
use plc_ast::{
ast::{
Expand Down Expand Up @@ -325,117 +328,107 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {
body: &[AstNode],
) -> Result<(), Diagnostic> {
let (builder, current_function, context) = self.get_llvm_deps();
self.generate_assignment_statement(counter, start)?;
let condition_check = context.append_basic_block(current_function, "condition_check");
let for_body = context.append_basic_block(current_function, "for_body");
let increment_block = context.append_basic_block(current_function, "increment");
let continue_block = context.append_basic_block(current_function, "continue");

//Generate an initial jump to the for condition
builder.build_unconditional_branch(condition_check);

//Check loop condition
builder.position_at_end(condition_check);
let exp_gen = self.create_expr_generator();
let counter_statement = exp_gen.generate_expression(counter)?;

//. / and_2 \
//. / and 1 \
//. (counter_end_le && counter_start_ge) || (counter_end_ge && counter_start_le)
let or_eval = self.generate_compare_expression(counter, end, start, &exp_gen)?;
let end_ty = self.annotations.get_type_or_void(end, self.index);
let counter_ty = self.annotations.get_type_or_void(counter, self.index);
let cast_target_ty = get_bigger_type(self.index.get_type_or_panic(DINT_TYPE), counter_ty, self.index);
let cast_target_llty = self.llvm_index.find_associated_type(cast_target_ty.get_name()).unwrap();

let step_ty = by_step.as_ref().map(|it| {
self.register_debug_location(it);
self.annotations.get_type_or_void(it, self.index)
});

let eval_step = || {
step_ty.map_or_else(
|| self.llvm.create_const_numeric(&cast_target_llty, "1", SourceLocation::undefined()),
|step_ty| {
let step = exp_gen.generate_expression(by_step.as_ref().unwrap())?;
Ok(cast_if_needed!(exp_gen, cast_target_ty, step_ty, step, None))
},
)
};

builder.build_conditional_branch(to_i1(or_eval.into_int_value(), builder), for_body, continue_block);
let predicate_incrementing = context.append_basic_block(current_function, "predicate_sle");
let predicate_decrementing = context.append_basic_block(current_function, "predicate_sge");
let loop_body = context.append_basic_block(current_function, "loop");
let increment = context.append_basic_block(current_function, "increment");
let afterloop = context.append_basic_block(current_function, "continue");

//Enter the for loop
builder.position_at_end(for_body);
let body_generator = StatementCodeGenerator {
current_loop_exit: Some(continue_block),
current_loop_continue: Some(increment_block),
self.generate_assignment_statement(counter, start)?;
let counter = exp_gen.generate_lvalue(counter)?;

// generate loop predicate selector. since `STEP` can be a reference, this needs to be a runtime eval
// XXX(mhasel): IR could possibly be improved by generating phi instructions.
// Candidate for frontend optimization for builds without optimization when `STEP`
// is a compile-time constant
let is_incrementing = builder.build_int_compare(
inkwell::IntPredicate::SGT,
eval_step()?.into_int_value(),
self.llvm
.create_const_numeric(&cast_target_llty, "0", SourceLocation::undefined())?
.into_int_value(),
"is_incrementing",
);
builder.build_conditional_branch(is_incrementing, predicate_incrementing, predicate_decrementing);
// generate predicates for incrementing and decrementing counters
let generate_predicate = |predicate| {
builder.position_at_end(match predicate {
inkwell::IntPredicate::SLE => predicate_incrementing,
inkwell::IntPredicate::SGE => predicate_decrementing,
_ => unreachable!(),
});

let end = exp_gen.generate_expression_value(end).unwrap();
let end_value = match end {
ExpressionValue::LValue(ptr) => builder.build_load(ptr, ""),
ExpressionValue::RValue(val) => val,
};
let counter_value = builder.build_load(counter, "");
let cmp = builder.build_int_compare(
predicate,
cast_if_needed!(exp_gen, cast_target_ty, counter_ty, counter_value, None).into_int_value(),
cast_if_needed!(exp_gen, cast_target_ty, end_ty, end_value, None).into_int_value(),
"condition",
);
builder.build_conditional_branch(cmp, loop_body, afterloop);
};
generate_predicate(inkwell::IntPredicate::SLE);
generate_predicate(inkwell::IntPredicate::SGE);

// generate loop body
builder.position_at_end(loop_body);
let body_builder = StatementCodeGenerator {
current_loop_continue: Some(increment),
current_loop_exit: Some(afterloop),
load_prefix: self.load_prefix.clone(),
load_suffix: self.load_suffix.clone(),
..*self
};
body_generator.generate_body(body)?;
builder.build_unconditional_branch(increment_block);

//Increment
builder.position_at_end(increment_block);
let expression_generator = self.create_expr_generator();
let step_by_value = by_step.as_ref().map_or_else(
|| {
self.llvm.create_const_numeric(
&counter_statement.get_type(),
"1",
SourceLocation::undefined(),
)
},
|step| {
self.register_debug_location(step);
expression_generator.generate_expression(step)
},
)?;

let next = builder.build_int_add(
counter_statement.into_int_value(),
step_by_value.into_int_value(),
"tmpVar",
body_builder.generate_body(body)?;

// increment counter
builder.build_unconditional_branch(increment);
builder.position_at_end(increment);
let counter_value = builder.build_load(counter, "");
let inc = inkwell::values::BasicValue::as_basic_value_enum(&builder.build_int_add(
eval_step()?.into_int_value(),
cast_if_needed!(exp_gen, cast_target_ty, counter_ty, counter_value, None).into_int_value(),
"next",
));
builder.build_store(
counter,
cast_if_needed!(exp_gen, counter_ty, cast_target_ty, inc, None).into_int_value(),
);

let ptr = expression_generator.generate_lvalue(counter)?;
builder.build_store(ptr, next);

//Loop back
builder.build_unconditional_branch(condition_check);

//Continue
builder.position_at_end(continue_block);

// check condition
builder.build_conditional_branch(is_incrementing, predicate_incrementing, predicate_decrementing);
// continue
builder.position_at_end(afterloop);
Ok(())
}

fn generate_compare_expression(
&'a self,
counter: &AstNode,
end: &AstNode,
start: &AstNode,
exp_gen: &'a ExpressionCodeGenerator,
) -> Result<BasicValueEnum<'a>, Diagnostic> {
let bool_id = self.annotations.get_bool_id();
let counter_end_ge = AstFactory::create_binary_expression(
counter.clone(),
Operator::GreaterOrEqual,
end.clone(),
bool_id,
);
let counter_start_ge = AstFactory::create_binary_expression(
counter.clone(),
Operator::GreaterOrEqual,
start.clone(),
bool_id,
);
let counter_end_le = AstFactory::create_binary_expression(
counter.clone(),
Operator::LessOrEqual,
end.clone(),
bool_id,
);
let counter_start_le = AstFactory::create_binary_expression(
counter.clone(),
Operator::LessOrEqual,
start.clone(),
bool_id,
);
let and_1 =
AstFactory::create_binary_expression(counter_end_le, Operator::And, counter_start_ge, bool_id);
let and_2 =
AstFactory::create_binary_expression(counter_end_ge, Operator::And, counter_start_le, bool_id);
let or = AstFactory::create_binary_expression(and_1, Operator::Or, and_2, bool_id);

self.register_debug_location(&or);
let or_eval = exp_gen.generate_expression(&or)?;
Ok(or_eval)
}

/// genertes a case statement
///
/// CASE selector OF
Expand Down
139 changes: 139 additions & 0 deletions src/codegen/tests/code_gen_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,145 @@ fn for_statement_with_references_steps_test() {
insta::assert_snapshot!(result);
}

#[test]
fn for_statement_with_binary_expressions() {
let result = codegen(
"
PROGRAM prg
VAR
step: DINT;
x : DINT;
y : DINT;
z : DINT;
END_VAR
FOR x := y + 1 TO z - 2 BY step * 3 DO
x;
END_FOR
END_PROGRAM
",
);

insta::assert_snapshot!(result, @r###"
; ModuleID = 'main'
source_filename = "main"
%prg = type { i32, i32, i32, i32 }
@prg_instance = global %prg zeroinitializer, section "var-$RUSTY$prg_instance:r4i32i32i32i32"
define void @prg(%prg* %0) section "fn-$RUSTY$prg:v" {
entry:
%step = getelementptr inbounds %prg, %prg* %0, i32 0, i32 0
%x = getelementptr inbounds %prg, %prg* %0, i32 0, i32 1
%y = getelementptr inbounds %prg, %prg* %0, i32 0, i32 2
%z = getelementptr inbounds %prg, %prg* %0, i32 0, i32 3
%load_y = load i32, i32* %y, align 4
%tmpVar = add i32 %load_y, 1
store i32 %tmpVar, i32* %x, align 4
%load_step = load i32, i32* %step, align 4
%tmpVar1 = mul i32 %load_step, 3
%is_incrementing = icmp sgt i32 %tmpVar1, 0
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge
predicate_sle: ; preds = %increment, %entry
%load_z = load i32, i32* %z, align 4
%tmpVar2 = sub i32 %load_z, 2
%1 = load i32, i32* %x, align 4
%condition = icmp sle i32 %1, %tmpVar2
br i1 %condition, label %loop, label %continue
predicate_sge: ; preds = %increment, %entry
%load_z3 = load i32, i32* %z, align 4
%tmpVar4 = sub i32 %load_z3, 2
%2 = load i32, i32* %x, align 4
%condition5 = icmp sge i32 %2, %tmpVar4
br i1 %condition5, label %loop, label %continue
loop: ; preds = %predicate_sge, %predicate_sle
%load_x = load i32, i32* %x, align 4
br label %increment
increment: ; preds = %loop
%3 = load i32, i32* %x, align 4
%load_step6 = load i32, i32* %step, align 4
%tmpVar7 = mul i32 %load_step6, 3
%next = add i32 %tmpVar7, %3
store i32 %next, i32* %x, align 4
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge
continue: ; preds = %predicate_sge, %predicate_sle
ret void
}
"###);
}

#[test]
fn for_statement_type_casting() {
let result = codegen(
"FUNCTION main
VAR
a: USINT;
b: INT := 1;
END_VAR
FOR a := 0 TO 10 BY b DO
b := b * 3;
END_FOR
END_FUNCTION",
);
insta::assert_snapshot!(result, @r###"
; ModuleID = 'main'
source_filename = "main"
define void @main() section "fn-$RUSTY$main:v" {
entry:
%a = alloca i8, align 1
%b = alloca i16, align 2
store i8 0, i8* %a, align 1
store i16 1, i16* %b, align 2
store i8 0, i8* %a, align 1
%load_b = load i16, i16* %b, align 2
%0 = trunc i16 %load_b to i8
%1 = sext i8 %0 to i32
%is_incrementing = icmp sgt i32 %1, 0
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge
predicate_sle: ; preds = %increment, %entry
%2 = load i8, i8* %a, align 1
%3 = zext i8 %2 to i32
%condition = icmp sle i32 %3, 10
br i1 %condition, label %loop, label %continue
predicate_sge: ; preds = %increment, %entry
%4 = load i8, i8* %a, align 1
%5 = zext i8 %4 to i32
%condition1 = icmp sge i32 %5, 10
br i1 %condition1, label %loop, label %continue
loop: ; preds = %predicate_sge, %predicate_sle
%load_b2 = load i16, i16* %b, align 2
%6 = sext i16 %load_b2 to i32
%tmpVar = mul i32 %6, 3
%7 = trunc i32 %tmpVar to i16
store i16 %7, i16* %b, align 2
br label %increment
increment: ; preds = %loop
%8 = load i8, i8* %a, align 1
%load_b3 = load i16, i16* %b, align 2
%9 = trunc i16 %load_b3 to i8
%10 = sext i8 %9 to i32
%11 = zext i8 %8 to i32
%next = add i32 %10, %11
%12 = trunc i32 %next to i8
store i8 %12, i8* %a, align 1
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge
continue: ; preds = %predicate_sge, %predicate_sle
ret void
}
"###);
}

#[test]
fn while_statement() {
let result = codegen(
Expand Down
Loading

0 comments on commit ef09b87

Please sign in to comment.