Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Kernel Parameters**: Declared with `PARAM <name> <size>`, each becomes a `memref<Nxi64>` function argument with `forth.param_name` attribute. Using a param name in code pushes its byte address onto the stack via `forth.param_ref`
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
Expand Down
72 changes: 69 additions & 3 deletions include/warpforth/Dialect/Forth/ForthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,51 @@ def Forth_GtOp : Forth_Op<"gt", [Pure]> {
}];
}

// forth.ne: inequality test on the two topmost stack cells.
// Declared Pure. The op takes the whole stack as its single operand and
// returns the updated stack as its single result, so the popped operands and
// the pushed flag do not appear in the op signature.
def Forth_NeOp : Forth_Op<"ne", [Pure]> {
let summary = "Test inequality of top two stack elements";
let description = [{
Pops two values, pushes -1 (true) if not equal, 0 (false) otherwise.
Forth semantics: ( a b -- flag )
}];

// Stack in, stack out: the only SSA operand/result pair of the op.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);

// Textual form: `<stack> <attr-dict> : <input type> -> <output type>`.
let assemblyFormat = [{
$input_stack attr-dict `:` type($input_stack) `->` type($output_stack)
}];
}

// forth.le: less-than-or-equal test on the two topmost stack cells.
// Declared Pure. The op takes the whole stack as its single operand and
// returns the updated stack as its single result, so the popped operands and
// the pushed flag do not appear in the op signature.
def Forth_LeOp : Forth_Op<"le", [Pure]> {
let summary = "Test less-than-or-equal of top two stack elements";
let description = [{
Pops two values, pushes -1 (true) if a <= b, 0 (false) otherwise.
Forth semantics: ( a b -- flag )
}];

// Stack in, stack out: the only SSA operand/result pair of the op.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);

// Textual form: `<stack> <attr-dict> : <input type> -> <output type>`.
let assemblyFormat = [{
$input_stack attr-dict `:` type($input_stack) `->` type($output_stack)
}];
}

// forth.ge: greater-than-or-equal test on the two topmost stack cells.
// Declared Pure. The op takes the whole stack as its single operand and
// returns the updated stack as its single result, so the popped operands and
// the pushed flag do not appear in the op signature.
def Forth_GeOp : Forth_Op<"ge", [Pure]> {
let summary = "Test greater-than-or-equal of top two stack elements";
let description = [{
Pops two values, pushes -1 (true) if a >= b, 0 (false) otherwise.
Forth semantics: ( a b -- flag )
}];

// Stack in, stack out: the only SSA operand/result pair of the op.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);

// Textual form: `<stack> <attr-dict> : <input type> -> <output type>`.
let assemblyFormat = [{
$input_stack attr-dict `:` type($input_stack) `->` type($output_stack)
}];
}

def Forth_ZeroEqOp : Forth_Op<"zero_eq", [Pure]> {
let summary = "Test if top of stack is zero";
let description = [{
Expand All @@ -660,15 +705,18 @@ def Forth_ZeroEqOp : Forth_Op<"zero_eq", [Pure]> {
//===----------------------------------------------------------------------===//

def Forth_YieldOp : Forth_Op<"yield", [Pure, Terminator, ReturnLike,
ParentOneOf<["IfOp", "BeginUntilOp", "DoLoopOp"]>]> {
ParentOneOf<["IfOp", "BeginUntilOp", "BeginWhileRepeatOp", "DoLoopOp"]>]> {
let summary = "Yield stack from control flow region";
let description = [{
Yields the current stack state from a control flow region back to
the parent operation. Acts as a region terminator.
When the optional `while_cond` attribute is present, the yield acts as
a WHILE condition (continue when flag is non-zero) rather than
UNTIL (exit when flag is non-zero).
}];
let arguments = (ins Forth_StackType:$result);
let arguments = (ins Forth_StackType:$result, OptionalAttr<UnitAttr>:$while_cond);
let assemblyFormat = [{
$result attr-dict `:` type($result)
$result (`while_cond` $while_cond^)? attr-dict `:` type($result)
}];
}

Expand Down Expand Up @@ -718,6 +766,24 @@ def Forth_DoLoopOp : Forth_Op<"do_loop", [RecursiveMemoryEffects,
let hasCustomAssemblyFormat = 1;
}

// forth.begin_while_repeat: pre-test loop with two regions, the condition
// region first and the body region second.
// RegionBranchOpInterface methods are declared on the op class here; the
// extra list names getEntrySuccessorOperands so that a declaration is emitted
// for it as well. Memory effects are derived from the nested ops
// (RecursiveMemoryEffects).
def Forth_BeginWhileRepeatOp : Forth_Op<"begin_while_repeat",
[RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>]> {
let summary = "Pre-test loop (BEGIN/WHILE/REPEAT)";
let description = [{
BEGIN/WHILE/REPEAT loop. The condition region runs first, WHILE pops flag.
If flag is non-zero, the body region executes and loops back to condition.
If flag is zero, the loop exits.
Stack effect: ( -- ) with flag consumed each iteration.
}];
// Stack in, stack out: the only SSA operand/result pair of the op.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);
// Each region is constrained to exactly one block.
let regions = (region SizedRegion<1>:$condition_region,
SizedRegion<1>:$body_region);
// print()/parse() are written by hand in the dialect implementation.
let hasCustomAssemblyFormat = 1;
}

def Forth_LoopIndexOp : Forth_Op<"loop_index", [Pure]> {
let summary = "Push loop index onto stack (I word)";
let description = [{ Only valid inside a forth.do_loop body. ( -- i ) }];
Expand Down
108 changes: 91 additions & 17 deletions lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ struct BinaryCmpOpConversion : public OpConversionPattern<ForthOp> {
// Compare
Value cmp = rewriter.create<arith::CmpIOp>(loc, predicate, a, b);

// Extend i1 to i64: true -1 (all bits set), false 0
// Extend i1 to i64: true = -1 (all bits set), false = 0
Value result =
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), cmp);

Expand All @@ -508,6 +508,12 @@ using LtOpConversion =
BinaryCmpOpConversion<forth::LtOp, arith::CmpIPredicate::slt>;
using GtOpConversion =
BinaryCmpOpConversion<forth::GtOp, arith::CmpIPredicate::sgt>;
// Lowerings for <>, <= and >=: each instantiates the shared binary-compare
// pattern with the matching arith.cmpi predicate. The ordered comparisons use
// the signed predicates (sle / sge); inequality uses ne.
using NeOpConversion =
BinaryCmpOpConversion<forth::NeOp, arith::CmpIPredicate::ne>;
using LeOpConversion =
BinaryCmpOpConversion<forth::LeOp, arith::CmpIPredicate::sle>;
using GeOpConversion =
BinaryCmpOpConversion<forth::GeOp, arith::CmpIPredicate::sge>;

/// Conversion pattern for forth.not operation (bitwise NOT).
/// Unary: pops one value, XORs with -1 (all bits set), pushes result: (a -- ~a)
Expand Down Expand Up @@ -564,7 +570,7 @@ struct ZeroEqOpConversion : public OpConversionPattern<forth::ZeroEqOp> {
Value cmp =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);

// Extend i1 to i64: true -1, false 0
// Extend i1 to i64: true = -1, false = 0
Value result =
rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), cmp);

Expand Down Expand Up @@ -777,8 +783,9 @@ struct GlobalIdOpConversion : public OpConversionPattern<forth::GlobalIdOp> {
};

/// Conversion pattern for forth.yield operation.
/// Context-aware: inside scf.while's `before` region (from BeginUntilOp),
/// emits flag-pop + scf.condition; otherwise emits scf.yield with SP.
/// Context-aware: inside scf.while's `before` region (from BeginUntilOp or
/// BeginWhileRepeatOp), emits flag-pop + scf.condition; otherwise emits
/// scf.yield with SP.
struct YieldOpConversion : public OpConversionPattern<forth::YieldOp> {
YieldOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<forth::YieldOp>(typeConverter, context) {}
Expand All @@ -801,11 +808,17 @@ struct YieldOpConversion : public OpConversionPattern<forth::YieldOp> {
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value spAfterPop = rewriter.create<arith::SubIOp>(loc, sp, one);

// UNTIL exits on non-zero; scf.while continues on true.
// So keep going when flag == 0.
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 64);
Value keepGoing = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, flag, zero);
Value keepGoing;
if (op.getWhileCond()) {
// WHILE semantics: continue when flag is non-zero.
keepGoing = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, flag, zero);
} else {
// UNTIL semantics: exit on non-zero; keep going when flag == 0.
keepGoing = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, flag, zero);
}

rewriter.replaceOpWithNewOp<scf::ConditionOp>(op, keepGoing,
ValueRange{spAfterPop});
Expand Down Expand Up @@ -849,7 +862,7 @@ struct IfOpConversion : public OpConversionPattern<forth::IfOp> {
/*addElseBlock=*/true);

// Convert block signatures and inline regions into scf.if.
// convertRegionTypes converts !forth.stack block arg {memref, index}
// convertRegionTypes converts !forth.stack block arg to {memref, index}
// and inserts tracked materializations (unrealized_conversion_cast).
// We inline into scf.if and mergeBlocks to substitute the converted
// block args with parent-scope values. The original materialization
Expand Down Expand Up @@ -905,7 +918,7 @@ struct BeginUntilOpConversion
auto whileOp = rewriter.create<scf::WhileOp>(loc, TypeRange{indexType},
ValueRange{stackPtr});

// --- Before region (body): convert + inline ---
// Before region (body): convert + inline.
Region &bodyRegion = op.getBodyRegion();
if (failed(rewriter.convertRegionTypes(&bodyRegion, *getTypeConverter())))
return failure();
Expand All @@ -923,7 +936,7 @@ struct BeginUntilOpConversion
Value beforeSP = newBeforeBlock->getArgument(0);
rewriter.mergeBlocks(&beforeBlock, newBeforeBlock, {memref, beforeSP});

// --- After region (identity): just yield the SP ---
// After region (identity): just yield the SP.
Block *afterBlock = rewriter.createBlock(&whileOp.getAfter());
afterBlock->addArgument(indexType, loc);
Value afterSP = afterBlock->getArgument(0);
Expand All @@ -936,6 +949,65 @@ struct BeginUntilOpConversion
}
};

/// Lowers forth.begin_while_repeat to scf.while.
/// The op's condition region is moved into the `before` region and its body
/// region into the `after` region. The stack pointer is the single
/// loop-carried value; the stack memref is substituted from the converted
/// operand in both regions.
struct BeginWhileRepeatOpConversion
    : public OpConversionPattern<forth::BeginWhileRepeatOp> {
  BeginWhileRepeatOpConversion(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern<forth::BeginWhileRepeatOp>(typeConverter, context) {
  }
  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(forth::BeginWhileRepeatOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The converted stack operand expands to {memref, stack pointer}.
    ValueRange convertedStack = adaptor.getOperands()[0];
    Value stackBuffer = convertedStack[0];
    Value initialSP = convertedStack[1];

    auto spType = rewriter.getIndexType();

    // scf.while carrying only the stack pointer through the loop.
    auto loop = rewriter.create<scf::WhileOp>(loc, TypeRange{spType},
                                              ValueRange{initialSP});

    // Converts the block signature of `source`, moves it into `target`, then
    // folds the moved block into a fresh entry block whose only argument is
    // the stack pointer. The two converted block arguments are replaced by
    // {stackBuffer, that stack pointer}.
    auto moveRegionInto = [&](Region &source, Region &target) -> LogicalResult {
      if (failed(rewriter.convertRegionTypes(&source, *getTypeConverter())))
        return failure();

      rewriter.inlineRegionBefore(source, target, target.end());

      Block &movedBlock = target.front();
      Block *entryBlock = rewriter.createBlock(&target);
      entryBlock->addArgument(spType, loc);
      Value regionSP = entryBlock->getArgument(0);
      rewriter.mergeBlocks(&movedBlock, entryBlock, {stackBuffer, regionSP});
      return success();
    };

    // Condition first (`before`), then body (`after`).
    if (failed(moveRegionInto(op.getConditionRegion(), loop.getBefore())))
      return failure();
    if (failed(moveRegionInto(op.getBodyRegion(), loop.getAfter())))
      return failure();

    // The op's stack result becomes {memref, final stack pointer}.
    rewriter.replaceOpWithMultiple(op, {{stackBuffer, loop.getResult(0)}});
    return success();
  }
};

/// Conversion pattern for forth.do_loop operation.
/// Pops start and limit from the stack, creates scf.for from start to limit.
struct DoLoopOpConversion : public OpConversionPattern<forth::DoLoopOp> {
Expand All @@ -962,7 +1034,7 @@ struct DoLoopOpConversion : public OpConversionPattern<forth::DoLoopOp> {
Value limitI64 = rewriter.create<memref::LoadOp>(loc, memref, spAfterStart);
Value spAfterPops = rewriter.create<arith::SubIOp>(loc, spAfterStart, one);

// Cast i64 index for scf.for bounds
// Cast i64 to index for scf.for bounds
Value startIdx =
rewriter.create<arith::IndexCastOp>(loc, indexType, startI64);
Value limitIdx =
Expand Down Expand Up @@ -1023,7 +1095,7 @@ struct LoopIndexOpConversion : public OpConversionPattern<forth::LoopIndexOp> {
if (!forOp)
return rewriter.notifyMatchFailure(op, "not inside an scf.for");

// Get induction variable and cast index i64
// Get induction variable and cast index to i64
Value iv = forOp.getInductionVar();
Value ivI64 =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(), iv);
Expand Down Expand Up @@ -1090,10 +1162,12 @@ struct ConvertForthToMemRefPass
AddOpConversion, SubOpConversion, MulOpConversion, DivOpConversion,
ModOpConversion, AndOpConversion, OrOpConversion, XorOpConversion,
NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion,
LtOpConversion, GtOpConversion, ZeroEqOpConversion,
ParamRefOpConversion, LoadOpConversion, StoreOpConversion,
IfOpConversion, BeginUntilOpConversion, DoLoopOpConversion,
LoopIndexOpConversion, YieldOpConversion>(typeConverter, context);
LtOpConversion, GtOpConversion, NeOpConversion, LeOpConversion,
GeOpConversion, ZeroEqOpConversion, ParamRefOpConversion,
LoadOpConversion, StoreOpConversion, IfOpConversion,
BeginUntilOpConversion, BeginWhileRepeatOpConversion,
DoLoopOpConversion, LoopIndexOpConversion, YieldOpConversion>(
typeConverter, context);

// Add GPU indexing op conversion patterns
patterns.add<IntrinsicOpConversion<forth::ThreadIdXOp>>(typeConverter,
Expand Down
70 changes: 70 additions & 0 deletions lib/Dialect/Forth/ForthDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,76 @@ ParseResult BeginUntilOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}

//===----------------------------------------------------------------------===//
// BeginWhileRepeatOp RegionBranchOpInterface.
//===----------------------------------------------------------------------===//

void BeginWhileRepeatOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  Region &condition = getConditionRegion();
  Region &body = getBodyRegion();

  // Leaving the condition region: control either enters the body or exits
  // to the parent with the op results.
  const bool leavingCondition =
      !point.isParent() && point.getRegionOrNull() == &condition;
  if (leavingCondition) {
    regions.push_back(RegionSuccessor(&body, body.getArguments()));
    regions.push_back(RegionSuccessor(getOperation()->getResults()));
    return;
  }

  // Entry from the parent and the back-edge from the body both target the
  // condition region.
  regions.push_back(RegionSuccessor(&condition, condition.getArguments()));
}

OperandRange
BeginWhileRepeatOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  // Every operand of the op is forwarded, whatever the branch point is.
  Operation *self = getOperation();
  return self->getOperands();
}

//===----------------------------------------------------------------------===//
// BeginWhileRepeatOp custom assembly format.
//===----------------------------------------------------------------------===//

void BeginWhileRepeatOp::print(OpAsmPrinter &p) {
  // Header: ` <stack> : <input type> -> <output type> `.
  auto stack = getInputStack();
  p << ' ' << stack;
  p << " : " << stack.getType();
  p << " -> " << getOutputStack().getType();
  p << ' ';

  // Condition region, the `do` keyword, then the body region.
  p.printRegion(getConditionRegion());
  p << " do ";
  p.printRegion(getBodyRegion());
}

ParseResult BeginWhileRepeatOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::UnresolvedOperand stackOperand;
  Type stackInType;
  Type stackOutType;

  // Header: `<stack> : <input type> -> <output type>`.
  if (parser.parseOperand(stackOperand))
    return failure();
  if (parser.parseColon() || parser.parseType(stackInType))
    return failure();
  if (parser.parseArrow() || parser.parseType(stackOutType))
    return failure();
  if (parser.resolveOperand(stackOperand, stackInType, result.operands))
    return failure();

  result.addTypes(stackOutType);

  // First region: the loop condition.
  Region *condition = result.addRegion();
  if (parser.parseRegion(*condition))
    return failure();

  // `do` separates the condition from the body.
  if (parser.parseKeyword("do"))
    return failure();

  // Second region: the loop body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// DoLoopOp RegionBranchOpInterface.
//===----------------------------------------------------------------------===//
Expand Down
Loading