diff --git a/CLAUDE.md b/CLAUDE.md index 6703e93..9755d5c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -87,7 +87,7 @@ uv run ruff format gpu_test/ - **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety - **Operations**: All take stack as input and produce stack as output (except `forth.stack`) -- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). +- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). - **Kernel Parameters**: Declared with `PARAM `, each becomes a `memref` function argument with `forth.param_name` attribute. Using a param name in code pushes its byte address onto the stack via `forth.param_ref` - **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer - **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td index f4b8b4a..5d36f64 100644 --- a/include/warpforth/Dialect/Forth/ForthOps.td +++ b/include/warpforth/Dialect/Forth/ForthOps.td @@ -640,6 +640,51 @@ def Forth_GtOp : Forth_Op<"gt", [Pure]> { }]; } +def Forth_NeOp : Forth_Op<"ne", [Pure]> { + let summary = "Test inequality of top two stack elements"; + let description = [{ + Pops two values, pushes -1 (true) if not equal, 0 (false) otherwise. + Forth semantics: ( a b -- flag ) + }]; + + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + }]; +} + +def Forth_LeOp : Forth_Op<"le", [Pure]> { + let summary = "Test less-than-or-equal of top two stack elements"; + let description = [{ + Pops two values, pushes -1 (true) if a <= b, 0 (false) otherwise. + Forth semantics: ( a b -- flag ) + }]; + + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + }]; +} + +def Forth_GeOp : Forth_Op<"ge", [Pure]> { + let summary = "Test greater-than-or-equal of top two stack elements"; + let description = [{ + Pops two values, pushes -1 (true) if a >= b, 0 (false) otherwise. + Forth semantics: ( a b -- flag ) + }]; + + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + }]; +} + def Forth_ZeroEqOp : Forth_Op<"zero_eq", [Pure]> { let summary = "Test if top of stack is zero"; let description = [{ @@ -660,15 +705,18 @@ def Forth_ZeroEqOp : Forth_Op<"zero_eq", [Pure]> { //===----------------------------------------------------------------------===// def Forth_YieldOp : Forth_Op<"yield", [Pure, Terminator, ReturnLike, - ParentOneOf<["IfOp", "BeginUntilOp", "DoLoopOp"]>]> { + ParentOneOf<["IfOp", "BeginUntilOp", "BeginWhileRepeatOp", "DoLoopOp"]>]> { let summary = "Yield stack from control flow region"; let description = [{ Yields the current stack state from a control flow region back to the parent operation. Acts as a region terminator. + When the optional `while_cond` attribute is present, the yield acts as + a WHILE condition (continue when flag is non-zero) rather than + UNTIL (exit when flag is non-zero). }]; - let arguments = (ins Forth_StackType:$result); + let arguments = (ins Forth_StackType:$result, OptionalAttr:$while_cond); let assemblyFormat = [{ - $result attr-dict `:` type($result) + $result (`while_cond` $while_cond^)? attr-dict `:` type($result) }]; } @@ -718,6 +766,24 @@ def Forth_DoLoopOp : Forth_Op<"do_loop", [RecursiveMemoryEffects, let hasCustomAssemblyFormat = 1; } +def Forth_BeginWhileRepeatOp : Forth_Op<"begin_while_repeat", + [RecursiveMemoryEffects, + DeclareOpInterfaceMethods]> { + let summary = "Pre-test loop (BEGIN/WHILE/REPEAT)"; + let description = [{ + BEGIN/WHILE/REPEAT loop. The condition region runs first, WHILE pops flag. + If flag is non-zero, the body region executes and loops back to condition. + If flag is zero, the loop exits. + Stack effect: ( -- ) with flag consumed each iteration. + }]; + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + let regions = (region SizedRegion<1>:$condition_region, + SizedRegion<1>:$body_region); + let hasCustomAssemblyFormat = 1; +} + def Forth_LoopIndexOp : Forth_Op<"loop_index", [Pure]> { let summary = "Push loop index onto stack (I word)"; let description = [{ Only valid inside a forth.do_loop body. ( -- i ) }]; diff --git a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp index 6539ecc..00ca8e1 100644 --- a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp +++ b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp @@ -488,7 +488,7 @@ struct BinaryCmpOpConversion : public OpConversionPattern { // Compare Value cmp = rewriter.create(loc, predicate, a, b); - // Extend i1 to i64: true → -1 (all bits set), false → 0 + // Extend i1 to i64: true = -1 (all bits set), false = 0 Value result = rewriter.create(loc, rewriter.getI64Type(), cmp); @@ -508,6 +508,12 @@ using LtOpConversion = BinaryCmpOpConversion; using GtOpConversion = BinaryCmpOpConversion; +using NeOpConversion = + BinaryCmpOpConversion; +using LeOpConversion = + BinaryCmpOpConversion; +using GeOpConversion = + BinaryCmpOpConversion; /// Conversion pattern for forth.not operation (bitwise NOT). /// Unary: pops one value, XORs with -1 (all bits set), pushes result: (a -- ~a) @@ -564,7 +570,7 @@ struct ZeroEqOpConversion : public OpConversionPattern { Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, a, zero); - // Extend i1 to i64: true → -1, false → 0 + // Extend i1 to i64: true = -1, false = 0 Value result = rewriter.create(loc, rewriter.getI64Type(), cmp); @@ -777,8 +783,9 @@ struct GlobalIdOpConversion : public OpConversionPattern { }; /// Conversion pattern for forth.yield operation. -/// Context-aware: inside scf.while's `before` region (from BeginUntilOp), -/// emits flag-pop + scf.condition; otherwise emits scf.yield with SP. +/// Context-aware: inside scf.while's `before` region (from BeginUntilOp or +/// BeginWhileRepeatOp), emits flag-pop + scf.condition; otherwise emits +/// scf.yield with SP. struct YieldOpConversion : public OpConversionPattern { YieldOpConversion(const TypeConverter &typeConverter, MLIRContext *context) : OpConversionPattern(typeConverter, context) {} @@ -801,11 +808,17 @@ struct YieldOpConversion : public OpConversionPattern { Value one = rewriter.create(loc, 1); Value spAfterPop = rewriter.create(loc, sp, one); - // UNTIL exits on non-zero; scf.while continues on true. - // So keep going when flag == 0. Value zero = rewriter.create(loc, 0, 64); - Value keepGoing = rewriter.create( - loc, arith::CmpIPredicate::eq, flag, zero); + Value keepGoing; + if (op.getWhileCond()) { + // WHILE semantics: continue when flag is non-zero. + keepGoing = rewriter.create( + loc, arith::CmpIPredicate::ne, flag, zero); + } else { + // UNTIL semantics: exit on non-zero; keep going when flag == 0. + keepGoing = rewriter.create( + loc, arith::CmpIPredicate::eq, flag, zero); + } rewriter.replaceOpWithNewOp(op, keepGoing, ValueRange{spAfterPop}); @@ -849,7 +862,7 @@ struct IfOpConversion : public OpConversionPattern { /*addElseBlock=*/true); // Convert block signatures and inline regions into scf.if. - // convertRegionTypes converts !forth.stack block arg → {memref, index} + // convertRegionTypes converts !forth.stack block arg to {memref, index} // and inserts tracked materializations (unrealized_conversion_cast). // We inline into scf.if and mergeBlocks to substitute the converted // block args with parent-scope values. The original materialization @@ -905,7 +918,7 @@ struct BeginUntilOpConversion auto whileOp = rewriter.create(loc, TypeRange{indexType}, ValueRange{stackPtr}); - // --- Before region (body): convert + inline --- + // Before region (body): convert + inline. Region &bodyRegion = op.getBodyRegion(); if (failed(rewriter.convertRegionTypes(&bodyRegion, *getTypeConverter()))) return failure(); @@ -923,7 +936,7 @@ struct BeginUntilOpConversion Value beforeSP = newBeforeBlock->getArgument(0); rewriter.mergeBlocks(&beforeBlock, newBeforeBlock, {memref, beforeSP}); - // --- After region (identity): just yield the SP --- + // After region (identity): just yield the SP. Block *afterBlock = rewriter.createBlock(&whileOp.getAfter()); afterBlock->addArgument(indexType, loc); Value afterSP = afterBlock->getArgument(0); @@ -936,6 +949,65 @@ struct BeginUntilOpConversion } }; +/// Conversion pattern for forth.begin_while_repeat operation. +/// Creates scf.while with the condition as the `before` region, +/// and the body as the `after` region. +struct BeginWhileRepeatOpConversion + : public OpConversionPattern { + BeginWhileRepeatOpConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : OpConversionPattern(typeConverter, context) { + } + using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(forth::BeginWhileRepeatOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + auto indexType = rewriter.getIndexType(); + + // Create scf.while with index result, stackPtr as iter arg. + auto whileOp = rewriter.create(loc, TypeRange{indexType}, + ValueRange{stackPtr}); + + // Before region (condition): convert + inline. + Region &condRegion = op.getConditionRegion(); + if (failed(rewriter.convertRegionTypes(&condRegion, *getTypeConverter()))) + return failure(); + + rewriter.inlineRegionBefore(condRegion, whileOp.getBefore(), + whileOp.getBefore().end()); + + Block &beforeBlock = whileOp.getBefore().front(); + Block *newBeforeBlock = rewriter.createBlock(&whileOp.getBefore()); + newBeforeBlock->addArgument(indexType, loc); + Value beforeSP = newBeforeBlock->getArgument(0); + rewriter.mergeBlocks(&beforeBlock, newBeforeBlock, {memref, beforeSP}); + + // After region (body): convert + inline. + Region &bodyRegion = op.getBodyRegion(); + if (failed(rewriter.convertRegionTypes(&bodyRegion, *getTypeConverter()))) + return failure(); + + rewriter.inlineRegionBefore(bodyRegion, whileOp.getAfter(), + whileOp.getAfter().end()); + + Block &afterBlock = whileOp.getAfter().front(); + Block *newAfterBlock = rewriter.createBlock(&whileOp.getAfter()); + newAfterBlock->addArgument(indexType, loc); + Value afterSP = newAfterBlock->getArgument(0); + rewriter.mergeBlocks(&afterBlock, newAfterBlock, {memref, afterSP}); + + // Replace forth.begin_while_repeat with {memref, whileOp result}. + rewriter.replaceOpWithMultiple(op, {{memref, whileOp.getResult(0)}}); + return success(); + } +}; + /// Conversion pattern for forth.do_loop operation. /// Pops start and limit from the stack, creates scf.for from start to limit. struct DoLoopOpConversion : public OpConversionPattern { @@ -962,7 +1034,7 @@ struct DoLoopOpConversion : public OpConversionPattern { Value limitI64 = rewriter.create(loc, memref, spAfterStart); Value spAfterPops = rewriter.create(loc, spAfterStart, one); - // Cast i64 → index for scf.for bounds + // Cast i64 to index for scf.for bounds Value startIdx = rewriter.create(loc, indexType, startI64); Value limitIdx = @@ -1023,7 +1095,7 @@ struct LoopIndexOpConversion : public OpConversionPattern { if (!forOp) return rewriter.notifyMatchFailure(op, "not inside an scf.for"); - // Get induction variable and cast index → i64 + // Get induction variable and cast index to i64 Value iv = forOp.getInductionVar(); Value ivI64 = rewriter.create(loc, rewriter.getI64Type(), iv); @@ -1090,10 +1162,12 @@ struct ConvertForthToMemRefPass AddOpConversion, SubOpConversion, MulOpConversion, DivOpConversion, ModOpConversion, AndOpConversion, OrOpConversion, XorOpConversion, NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion, - LtOpConversion, GtOpConversion, ZeroEqOpConversion, - ParamRefOpConversion, LoadOpConversion, StoreOpConversion, - IfOpConversion, BeginUntilOpConversion, DoLoopOpConversion, - LoopIndexOpConversion, YieldOpConversion>(typeConverter, context); + LtOpConversion, GtOpConversion, NeOpConversion, LeOpConversion, + GeOpConversion, ZeroEqOpConversion, ParamRefOpConversion, + LoadOpConversion, StoreOpConversion, IfOpConversion, + BeginUntilOpConversion, BeginWhileRepeatOpConversion, + DoLoopOpConversion, LoopIndexOpConversion, YieldOpConversion>( + typeConverter, context); // Add GPU indexing op conversion patterns patterns.add>(typeConverter, diff --git a/lib/Dialect/Forth/ForthDialect.cpp b/lib/Dialect/Forth/ForthDialect.cpp index 687b720..5521f64 100644 --- a/lib/Dialect/Forth/ForthDialect.cpp +++ b/lib/Dialect/Forth/ForthDialect.cpp @@ -134,6 +134,76 @@ ParseResult BeginUntilOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +//===----------------------------------------------------------------------===// +// BeginWhileRepeatOp RegionBranchOpInterface. +//===----------------------------------------------------------------------===// + +void BeginWhileRepeatOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) { + if (point.isParent()) { + // From parent: enter the condition region. + regions.push_back(RegionSuccessor(&getConditionRegion(), + getConditionRegion().getArguments())); + return; + } + if (point.getRegionOrNull() == &getConditionRegion()) { + // From condition: enter body or exit to parent. + regions.push_back( + RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); + regions.push_back(RegionSuccessor(getOperation()->getResults())); + return; + } + // From body: loop back to condition. + regions.push_back(RegionSuccessor(&getConditionRegion(), + getConditionRegion().getArguments())); +} + +OperandRange +BeginWhileRepeatOp::getEntrySuccessorOperands(RegionBranchPoint point) { + return getOperation()->getOperands(); +} + +//===----------------------------------------------------------------------===// +// BeginWhileRepeatOp custom assembly format. +//===----------------------------------------------------------------------===// + +void BeginWhileRepeatOp::print(OpAsmPrinter &p) { + p << ' ' << getInputStack() << " : " << getInputStack().getType() << " -> " + << getOutputStack().getType() << ' '; + p.printRegion(getConditionRegion()); + p << " do "; + p.printRegion(getBodyRegion()); +} + +ParseResult BeginWhileRepeatOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand inputStack; + Type inputType, outputType; + + if (parser.parseOperand(inputStack) || parser.parseColon() || + parser.parseType(inputType) || parser.parseArrow() || + parser.parseType(outputType) || + parser.resolveOperand(inputStack, inputType, result.operands)) + return failure(); + + result.addTypes(outputType); + + // Parse condition region. + auto *condRegion = result.addRegion(); + if (parser.parseRegion(*condRegion)) + return failure(); + + // Parse "do" keyword and body region. + if (parser.parseKeyword("do")) + return failure(); + + auto *bodyRegion = result.addRegion(); + if (parser.parseRegion(*bodyRegion)) + return failure(); + + return success(); +} + //===----------------------------------------------------------------------===// // DoLoopOp RegionBranchOpInterface. //===----------------------------------------------------------------------===// diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp index ef012cf..c2a5d8c 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp @@ -347,6 +347,12 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, return builder.create(loc, stackType, inputStack).getResult(); } else if (word == ">") { return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "<>") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "<=") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == ">=") { + return builder.create(loc, stackType, inputStack).getResult(); } else if (word == "0=") { return builder.create(loc, stackType, inputStack) .getResult(); @@ -407,7 +413,10 @@ ForthParser::parseBody(Value &stack, } else if (currentToken.text == "BEGIN") { Location tokenLoc = getLoc(); consume(); // consume BEGIN - stack = parseBeginUntil(stack, tokenLoc); + if (isWhileLoop()) + stack = parseBeginWhileRepeat(stack, tokenLoc); + else + stack = parseBeginUntil(stack, tokenLoc); if (!stack) return failure(); } else if (currentToken.text == "DO") { @@ -461,7 +470,7 @@ Value ForthParser::parseIf(Value inputStack, Location loc) { builder.create(loc, stackType, thenArg).getResult(); if (failed(parseBody(thenStack, isElseOrThen))) return nullptr; - builder.create(getLoc(), thenStack); + builder.create(getLoc(), thenStack, /*while_cond=*/nullptr); // --- Else region --- Block *elseBlock = new Block(); @@ -476,14 +485,14 @@ Value ForthParser::parseIf(Value inputStack, Location loc) { builder.create(loc, stackType, elseArg).getResult(); if (failed(parseBody(elseStack, isThen))) return nullptr; - builder.create(getLoc(), elseStack); + builder.create(getLoc(), elseStack, /*while_cond=*/nullptr); } else { // No ELSE clause — just drop the flag and yield (identity). builder.setInsertionPointToStart(elseBlock); Value elseArg = elseBlock->getArgument(0); Value elseStack = builder.create(loc, stackType, elseArg).getResult(); - builder.create(loc, elseStack); + builder.create(loc, elseStack, /*while_cond=*/nullptr); } // Consume THEN. @@ -520,7 +529,7 @@ Value ForthParser::parseBeginUntil(Value inputStack, Location loc) { Value bodyStack = bodyBlock->getArgument(0); if (failed(parseBody(bodyStack, isUntil))) return nullptr; - builder.create(getLoc(), bodyStack); + builder.create(getLoc(), bodyStack, /*while_cond=*/nullptr); // Consume UNTIL. if (currentToken.kind != Token::Kind::Word || currentToken.text != "UNTIL") { @@ -534,6 +543,96 @@ Value ForthParser::parseBeginUntil(Value inputStack, Location loc) { return beginUntilOp.getOutputStack(); } +//===----------------------------------------------------------------------===// +// BEGIN / WHILE / REPEAT lookahead + parsing. +//===----------------------------------------------------------------------===// + +bool ForthParser::isWhileLoop() { + // Save lexer position and current token. + const char *savedPos = lexer.getPosition(); + Token savedToken = currentToken; + + int depth = 0; + while (currentToken.kind != Token::Kind::EndOfFile) { + if (currentToken.kind == Token::Kind::Word) { + if (currentToken.text == "BEGIN" || currentToken.text == "DO") + ++depth; + else if (depth == 0 && currentToken.text == "UNTIL") { + // Found UNTIL at our nesting level → not a WHILE loop. + lexer.setPosition(savedPos); + currentToken = savedToken; + return false; + } else if (depth == 0 && currentToken.text == "WHILE") { + // Found WHILE at our nesting level → is a WHILE loop. + lexer.setPosition(savedPos); + currentToken = savedToken; + return true; + } else if (currentToken.text == "UNTIL" || currentToken.text == "LOOP" || + currentToken.text == "REPEAT") + --depth; + } + consume(); + } + + // Reached EOF without finding UNTIL or WHILE — restore and return false. + lexer.setPosition(savedPos); + currentToken = savedToken; + return false; +} + +Value ForthParser::parseBeginWhileRepeat(Value inputStack, Location loc) { + Type stackType = forth::StackType::get(context); + + // Create forth.begin_while_repeat op. + auto bwrOp = + builder.create(loc, stackType, inputStack); + + auto isWhile = [](StringRef word) { return word == "WHILE"; }; + auto isRepeat = [](StringRef word) { return word == "REPEAT"; }; + + // --- Condition region --- + Block *condBlock = new Block(); + condBlock->addArgument(stackType, loc); + bwrOp.getConditionRegion().push_back(condBlock); + + builder.setInsertionPointToStart(condBlock); + Value condStack = condBlock->getArgument(0); + if (failed(parseBody(condStack, isWhile))) + return nullptr; + // Terminate with forth.yield {while_cond} to indicate WHILE semantics. + builder.create(getLoc(), condStack, + /*while_cond=*/builder.getUnitAttr()); + + // Consume WHILE. + if (currentToken.kind != Token::Kind::Word || currentToken.text != "WHILE") { + (void)emitError("expected 'WHILE'"); + return nullptr; + } + consume(); // consume WHILE + + // --- Body region --- + Block *bodyBlock = new Block(); + bodyBlock->addArgument(stackType, loc); + bwrOp.getBodyRegion().push_back(bodyBlock); + + builder.setInsertionPointToStart(bodyBlock); + Value bodyStack = bodyBlock->getArgument(0); + if (failed(parseBody(bodyStack, isRepeat))) + return nullptr; + builder.create(getLoc(), bodyStack, /*while_cond=*/nullptr); + + // Consume REPEAT. + if (currentToken.kind != Token::Kind::Word || currentToken.text != "REPEAT") { + (void)emitError("expected 'REPEAT'"); + return nullptr; + } + consume(); // consume REPEAT + + // Restore insertion point to after the forth.begin_while_repeat op. + builder.setInsertionPointAfter(bwrOp); + return bwrOp.getOutputStack(); +} + //===----------------------------------------------------------------------===// // DO / LOOP parsing. //===----------------------------------------------------------------------===// @@ -559,7 +658,7 @@ Value ForthParser::parseDoLoop(Value inputStack, Location loc) { return nullptr; } --doLoopDepth; - builder.create(getLoc(), bodyStack); + builder.create(getLoc(), bodyStack, /*while_cond=*/nullptr); // Consume LOOP. if (currentToken.kind != Token::Kind::Word || currentToken.text != "LOOP") { diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.h b/lib/Translation/ForthToMLIR/ForthToMLIR.h index 33a2d6e..0369f0a 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.h +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.h @@ -47,6 +47,10 @@ class ForthLexer { /// Reset lexer to beginning of buffer. void reset(); + /// Save/restore lexer position for lookahead. + const char *getPosition() const { return curPtr; } + void setPosition(const char *pos) { curPtr = pos; } + private: llvm::SourceMgr &sourceMgr; unsigned bufferID; @@ -112,6 +116,13 @@ class ForthParser { /// Parse a BEGIN/UNTIL loop, creating a forth.begin_until op. Value parseBeginUntil(Value inputStack, Location loc); + /// Parse a BEGIN/WHILE/REPEAT loop, creating a forth.begin_while_repeat op. + Value parseBeginWhileRepeat(Value inputStack, Location loc); + + /// Lookahead: is the current BEGIN a WHILE loop (vs UNTIL)? + /// Saves and restores lexer position. + bool isWhileLoop(); + /// Parse a DO/LOOP counted loop, creating a forth.do_loop op. Value parseDoLoop(Value inputStack, Location loc); diff --git a/test/Conversion/ForthToMemRef/begin-while-repeat.mlir b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir new file mode 100644 index 0000000..58978bc --- /dev/null +++ b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir @@ -0,0 +1,44 @@ +// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s + +// CHECK-LABEL: func.func private @main + +// Verify scf.while with index iter arg: +// CHECK: scf.while (%{{.*}} = %{{.*}}) : (index) -> index { + +// Condition region: operations + flag pop + condition (ne for WHILE) +// CHECK: memref.load +// CHECK: arith.cmpi sgt +// CHECK: arith.extsi +// CHECK: memref.load +// CHECK: arith.subi +// CHECK: arith.cmpi ne +// CHECK: scf.condition(%{{.*}}) %{{.*}} : index + +// Body region: operations + yield +// CHECK: } do { +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: memref.load +// CHECK: arith.subi +// CHECK: scf.yield %{{.*}} : index +// CHECK: } + +module { + func.func private @main() { + %0 = forth.stack !forth.stack + %1 = forth.literal %0 10 : !forth.stack -> !forth.stack + %2 = forth.begin_while_repeat %1 : !forth.stack -> !forth.stack { + ^bb0(%arg0: !forth.stack): + %3 = forth.dup %arg0 : !forth.stack -> !forth.stack + %4 = forth.literal %3 0 : !forth.stack -> !forth.stack + %5 = forth.gt %4 : !forth.stack -> !forth.stack + forth.yield %5 while_cond : !forth.stack + } do { + ^bb0(%arg1: !forth.stack): + %6 = forth.literal %arg1 1 : !forth.stack -> !forth.stack + %7 = forth.sub %6 : !forth.stack -> !forth.stack + forth.yield %7 : !forth.stack + } + return + } +} diff --git a/test/Conversion/ForthToMemRef/comparison.mlir b/test/Conversion/ForthToMemRef/comparison.mlir index cf3e0d5..2a03b13 100644 --- a/test/Conversion/ForthToMemRef/comparison.mlir +++ b/test/Conversion/ForthToMemRef/comparison.mlir @@ -31,6 +31,27 @@ // CHECK: arith.extsi %{{.*}} : i1 to i64 // CHECK: memref.store +// ne: load two values, arith.cmpi ne, extsi to i64, store +// CHECK: memref.load +// CHECK: memref.load +// CHECK: arith.cmpi ne, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.extsi %{{.*}} : i1 to i64 +// CHECK: memref.store + +// le: load two values, arith.cmpi sle, extsi to i64, store +// CHECK: memref.load +// CHECK: memref.load +// CHECK: arith.cmpi sle, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.extsi %{{.*}} : i1 to i64 +// CHECK: memref.store + +// ge: load two values, arith.cmpi sge, extsi to i64, store +// CHECK: memref.load +// CHECK: memref.load +// CHECK: arith.cmpi sge, %{{.*}}, %{{.*}} : i64 +// CHECK: arith.extsi %{{.*}} : i1 to i64 +// CHECK: memref.store + module { func.func private @main() { %0 = forth.stack !forth.stack @@ -45,6 +66,15 @@ module { %9 = forth.gt %8 : !forth.stack -> !forth.stack %10 = forth.literal %9 0 : !forth.stack -> !forth.stack %11 = forth.zero_eq %10 : !forth.stack -> !forth.stack + %12 = forth.literal %11 7 : !forth.stack -> !forth.stack + %13 = forth.literal %12 8 : !forth.stack -> !forth.stack + %14 = forth.ne %13 : !forth.stack -> !forth.stack + %15 = forth.literal %14 9 : !forth.stack -> !forth.stack + %16 = forth.literal %15 10 : !forth.stack -> !forth.stack + %17 = forth.le %16 : !forth.stack -> !forth.stack + %18 = forth.literal %17 11 : !forth.stack -> !forth.stack + %19 = forth.literal %18 12 : !forth.stack -> !forth.stack + %20 = forth.ge %19 : !forth.stack -> !forth.stack return } } diff --git a/test/Pipeline/begin-while-repeat.forth b/test/Pipeline/begin-while-repeat.forth new file mode 100644 index 0000000..b19b9c8 --- /dev/null +++ b/test/Pipeline/begin-while-repeat.forth @@ -0,0 +1,15 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-scf-to-cf --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID + +\ Verify that BEGIN/WHILE/REPEAT through the full pipeline produces a gpu.binary +\ CHECK: gpu.binary @warpforth_module + +\ Verify intermediate MLIR: gpu.func with conditional branch +\ MID: gpu.module @warpforth_module +\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: cf.br +\ MID: cf.cond_br +\ MID: gpu.return + +PARAM DATA 4 +10 BEGIN DUP 0 > WHILE 1 - REPEAT DATA 0 CELLS + ! diff --git a/test/Translation/Forth/begin-while-repeat.forth b/test/Translation/Forth/begin-while-repeat.forth new file mode 100644 index 0000000..e8aea5d --- /dev/null +++ b/test/Translation/Forth/begin-while-repeat.forth @@ -0,0 +1,20 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ Verify BEGIN/WHILE/REPEAT parsing produces forth.begin_while_repeat +\ with condition and body regions + +\ CHECK: %[[S0:.*]] = forth.stack +\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] 10 +\ CHECK: %[[LOOP:.*]] = forth.begin_while_repeat %[[S1]] +\ CHECK: ^bb0(%[[CARG:.*]]: !forth.stack): +\ CHECK: forth.dup +\ CHECK: forth.literal +\ CHECK: forth.gt +\ CHECK: forth.yield %{{.*}} while_cond +\ CHECK: } do { +\ CHECK: ^bb0(%[[BARG:.*]]: !forth.stack): +\ CHECK: forth.literal +\ CHECK: forth.sub +\ CHECK: forth.yield +\ CHECK: } +10 BEGIN DUP 0 > WHILE 1 - REPEAT diff --git a/test/Translation/Forth/comparison-ops.forth b/test/Translation/Forth/comparison-ops.forth index 14ea0da..5e5746b 100644 --- a/test/Translation/Forth/comparison-ops.forth +++ b/test/Translation/Forth/comparison-ops.forth @@ -12,5 +12,14 @@ \ CHECK: %[[S8:.*]] = forth.literal %[[S7]] \ CHECK: %[[S9:.*]] = forth.gt %[[S8]] \ CHECK: %[[S10:.*]] = forth.literal %[[S9]] -\ CHECK: %{{.*}} = forth.zero_eq %[[S10]] -1 2 = 3 4 < 5 6 > 0 0= +\ CHECK: %[[S11:.*]] = forth.zero_eq %[[S10]] +\ CHECK: %[[S12:.*]] = forth.literal %[[S11]] +\ CHECK: %[[S13:.*]] = forth.literal %[[S12]] +\ CHECK: %[[S14:.*]] = forth.ne %[[S13]] +\ CHECK: %[[S15:.*]] = forth.literal %[[S14]] +\ CHECK: %[[S16:.*]] = forth.literal %[[S15]] +\ CHECK: %[[S17:.*]] = forth.le %[[S16]] +\ CHECK: %[[S18:.*]] = forth.literal %[[S17]] +\ CHECK: %[[S19:.*]] = forth.literal %[[S18]] +\ CHECK: %{{.*}} = forth.ge %[[S19]] +1 2 = 3 4 < 5 6 > 0 0= 7 8 <> 9 10 <= 11 12 >=