Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT`, `+ - * / MOD`, `= < > 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Kernel Parameters**: Declared with `PARAM <name> <size>`, each becomes a `memref<Nxi64>` function argument with `forth.param_name` attribute. Using a param name in code pushes its byte address onto the stack via `forth.param_ref`.
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
Expand Down
61 changes: 61 additions & 0 deletions include/warpforth/Dialect/Forth/ForthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,67 @@ def Forth_RotOp : Forth_Op<"rot", [Pure]> {
}];
}

// ODS definition of the `nip` operation of the Forth dialect.
// The op takes one stack value and yields one stack value; the ( a b -- b )
// effect is documented in the description, not encoded in the signature.
def Forth_NipOp : Forth_Op<"nip", [Pure]> {
let summary = "Remove second stack element";
let description = [{
Removes the second element from the stack, keeping the top.
Forth semantics: ( a b -- b )
}];

// Single stack operand in, single stack result out.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);

// Textual form: <input> [attr-dict] : <input type> -> <output type>
let assemblyFormat = [{
$input_stack attr-dict `:` type($input_stack) `->` type($output_stack)
}];
}

// ODS definition of the `tuck` operation of the Forth dialect.
// The op takes one stack value and yields one stack value; the
// ( a b -- b a b ) effect is documented in the description, not encoded in
// the signature.
def Forth_TuckOp : Forth_Op<"tuck", [Pure]> {
let summary = "Copy top element before second";
let description = [{
Copies the top element and inserts it before the second element.
Forth semantics: ( a b -- b a b )
}];

// Single stack operand in, single stack result out.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);

// Textual form: <input> [attr-dict] : <input type> -> <output type>
let assemblyFormat = [{
$input_stack attr-dict `:` type($input_stack) `->` type($output_stack)
}];
}

// ODS definition of the `pick` operation of the Forth dialect.
// The count n is not an operand of the op: per the description it is popped
// from the stack itself, so the only operand is the input stack.
def Forth_PickOp : Forth_Op<"pick", [Pure]> {
let summary = "Copy nth element to top";
let description = [{
Pops n from the stack, then copies the nth element to the top.
Forth semantics: ( xn ... x0 n -- xn ... x0 xn )
}];

// Single stack operand in, single stack result out.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);

// Textual form: <input> [attr-dict] : <input type> -> <output type>
let assemblyFormat = [{
$input_stack attr-dict `:` type($input_stack) `->` type($output_stack)
}];
}

// ODS definition of the `roll` operation of the Forth dialect.
// The count n is not an operand of the op: per the description it is popped
// from the stack itself, so the only operand is the input stack.
def Forth_RollOp : Forth_Op<"roll", [Pure]> {
let summary = "Rotate nth element to top";
let description = [{
Pops n from the stack, then rotates the nth element to the top,
shifting elements above it down.
Forth semantics: ( xn ... x0 n -- xn-1 ... x0 xn )
}];

// Single stack operand in, single stack result out.
let arguments = (ins Forth_StackType:$input_stack);
let results = (outs Forth_StackType:$output_stack);

// Textual form: <input> [attr-dict] : <input type> -> <output type>
let assemblyFormat = [{
$input_stack attr-dict `:` type($input_stack) `->` type($output_stack)
}];
}

//===----------------------------------------------------------------------===//
// Literal operations.
//===----------------------------------------------------------------------===//
Expand Down
176 changes: 165 additions & 11 deletions lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,160 @@ struct RotOpConversion : public OpConversionPattern<forth::RotOp> {
}
};

/// Lowers forth.nip onto the memref-backed stack: ( a b -- b ).
///
/// The converted stack operand arrives as two values: the backing memref and
/// the index of the top slot. The top value is copied one slot down (over
/// `a`), and that lower index becomes the new stack pointer.
struct NipOpConversion : public OpConversionPattern<forth::NipOp> {
  NipOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<forth::NipOp>(typeConverter, context) {}
  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(forth::NipOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Unpack the 1:N-converted stack into (buffer, top index).
    ValueRange stackParts = adaptor.getOperands()[0];
    Value buffer = stackParts[0];
    Value topIdx = stackParts[1];

    // Read b from the top slot.
    Value topValue = rewriter.create<memref::LoadOp>(loc, buffer, topIdx);

    // Overwrite a, which lives one slot below the top, with b.
    Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value belowIdx = rewriter.create<arith::SubIOp>(loc, topIdx, step);
    rewriter.create<memref::StoreOp>(loc, topValue, buffer, belowIdx);

    // The stack shrank by one element, so the new top is the lower slot.
    rewriter.replaceOpWithMultiple(op, {{buffer, belowIdx}});
    return success();
  }
};

/// Lowers forth.tuck onto the memref-backed stack: ( a b -- b a b ).
///
/// Both top values are read first, then written back as b, a, b into the
/// slots SP-1, SP and SP+1. The new stack pointer is SP+1.
struct TuckOpConversion : public OpConversionPattern<forth::TuckOp> {
  TuckOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<forth::TuckOp>(typeConverter, context) {}
  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(forth::TuckOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Unpack the 1:N-converted stack into (buffer, top index).
    ValueRange stackParts = adaptor.getOperands()[0];
    Value buffer = stackParts[0];
    Value topIdx = stackParts[1];

    Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Read b (top) and a (second) before any slot is overwritten.
    Value topValue = rewriter.create<memref::LoadOp>(loc, buffer, topIdx);
    Value belowIdx = rewriter.create<arith::SubIOp>(loc, topIdx, step);
    Value secondValue = rewriter.create<memref::LoadOp>(loc, buffer, belowIdx);

    // Write the new layout bottom-up: b, a, b.
    rewriter.create<memref::StoreOp>(loc, topValue, buffer, belowIdx);
    rewriter.create<memref::StoreOp>(loc, secondValue, buffer, topIdx);
    Value grownIdx = rewriter.create<arith::AddIOp>(loc, topIdx, step);
    rewriter.create<memref::StoreOp>(loc, topValue, buffer, grownIdx);

    // The stack grew by one element.
    rewriter.replaceOpWithMultiple(op, {{buffer, grownIdx}});
    return success();
  }
};

/// Lowers forth.pick onto the memref-backed stack:
/// ( xn ... x0 n -- xn ... x0 xn ).
///
/// The count n is read from the top slot. With SP' = SP - 1 denoting the top
/// after n is popped, xn lives at SP' - n. Its value is written into the slot
/// n occupied, so the resulting stack pointer equals the incoming one.
struct PickOpConversion : public OpConversionPattern<forth::PickOp> {
  PickOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<forth::PickOp>(typeConverter, context) {}
  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(forth::PickOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Unpack the 1:N-converted stack into (buffer, top index).
    ValueRange stackParts = adaptor.getOperands()[0];
    Value buffer = stackParts[0];
    Value topIdx = stackParts[1];

    // Read the count n from the top and compute the post-pop top index.
    Value countRaw = rewriter.create<memref::LoadOp>(loc, buffer, topIdx);
    Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value poppedIdx = rewriter.create<arith::SubIOp>(loc, topIdx, step);

    // memref indexing needs an index-typed value, so cast the loaded count.
    Type indexTy = rewriter.getIndexType();
    Value count = rewriter.create<arith::IndexCastOp>(loc, indexTy, countRaw);

    // xn sits n slots below the post-pop top.
    Value sourceIdx = rewriter.create<arith::SubIOp>(loc, poppedIdx, count);
    Value picked = rewriter.create<memref::LoadOp>(loc, buffer, sourceIdx);

    // Reuse the slot that held n for the copied value.
    rewriter.create<memref::StoreOp>(loc, picked, buffer, topIdx);

    // One element popped and one pushed: the stack pointer is unchanged.
    rewriter.replaceOpWithMultiple(op, {{buffer, topIdx}});
    return success();
  }
};

/// Lowers forth.roll onto the memref-backed stack:
/// ( xn ... x0 n -- xn-1 ... x0 xn ).
///
/// The count n is read from the top slot. With SP' = SP - 1 denoting the top
/// after n is popped, xn is saved from slot SP' - n. An scf.for loop then
/// moves every slot in [SP' - n, SP') one position down, and the saved value
/// is written at SP'. The resulting stack pointer is SP'.
struct RollOpConversion : public OpConversionPattern<forth::RollOp> {
  RollOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern<forth::RollOp>(typeConverter, context) {}
  using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(forth::RollOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Unpack the 1:N-converted stack into (buffer, top index).
    ValueRange stackParts = adaptor.getOperands()[0];
    Value buffer = stackParts[0];
    Value topIdx = stackParts[1];

    // Read the count n from the top and compute the post-pop top index.
    Value countRaw = rewriter.create<memref::LoadOp>(loc, buffer, topIdx);
    Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value poppedIdx = rewriter.create<arith::SubIOp>(loc, topIdx, step);

    // memref indexing needs an index-typed value, so cast the loaded count.
    Type indexTy = rewriter.getIndexType();
    Value count = rewriter.create<arith::IndexCastOp>(loc, indexTy, countRaw);

    // Save xn before the shift loop overwrites its slot.
    Value sourceIdx = rewriter.create<arith::SubIOp>(loc, poppedIdx, count);
    Value saved = rewriter.create<memref::LoadOp>(loc, buffer, sourceIdx);

    // for i in [sourceIdx, poppedIdx): buffer[i] = buffer[i + 1]
    auto shiftLoop = rewriter.create<scf::ForOp>(loc, sourceIdx, poppedIdx, step);
    {
      // Emit the body ahead of the loop's auto-created terminator; the guard
      // puts the insertion point back after the loop when this scope ends.
      OpBuilder::InsertionGuard bodyGuard(rewriter);
      rewriter.setInsertionPointToStart(shiftLoop.getBody());
      Value slot = shiftLoop.getInductionVar();
      Value nextSlot = rewriter.create<arith::AddIOp>(loc, slot, step);
      Value moved = rewriter.create<memref::LoadOp>(loc, buffer, nextSlot);
      rewriter.create<memref::StoreOp>(loc, moved, buffer, slot);
    }

    // Place the saved xn on the post-pop top.
    rewriter.create<memref::StoreOp>(loc, saved, buffer, poppedIdx);

    // n was consumed, so the stack shrank by one element.
    rewriter.replaceOpWithMultiple(op, {{buffer, poppedIdx}});
    return success();
  }
};

/// Base template for binary arithmetic operations.
/// Pops two values, applies operation, pushes result: (a b -- result)
template <typename ForthOp, typename ArithOp>
Expand Down Expand Up @@ -929,17 +1083,17 @@ struct ConvertForthToMemRefPass
RewritePatternSet patterns(context);

// Add Forth operation conversion patterns
patterns
.add<StackOpConversion, LiteralOpConversion, DupOpConversion,
DropOpConversion, SwapOpConversion, OverOpConversion,
RotOpConversion, AddOpConversion, SubOpConversion, MulOpConversion,
DivOpConversion, ModOpConversion, AndOpConversion, OrOpConversion,
XorOpConversion, NotOpConversion, LshiftOpConversion,
RshiftOpConversion, EqOpConversion, LtOpConversion, GtOpConversion,
ZeroEqOpConversion, ParamRefOpConversion, LoadOpConversion,
StoreOpConversion, IfOpConversion, BeginUntilOpConversion,
DoLoopOpConversion, LoopIndexOpConversion, YieldOpConversion>(
typeConverter, context);
patterns.add<
StackOpConversion, LiteralOpConversion, DupOpConversion,
DropOpConversion, SwapOpConversion, OverOpConversion, RotOpConversion,
NipOpConversion, TuckOpConversion, PickOpConversion, RollOpConversion,
AddOpConversion, SubOpConversion, MulOpConversion, DivOpConversion,
ModOpConversion, AndOpConversion, OrOpConversion, XorOpConversion,
NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion,
LtOpConversion, GtOpConversion, ZeroEqOpConversion,
ParamRefOpConversion, LoadOpConversion, StoreOpConversion,
IfOpConversion, BeginUntilOpConversion, DoLoopOpConversion,
LoopIndexOpConversion, YieldOpConversion>(typeConverter, context);

// Add GPU indexing op conversion patterns
patterns.add<IntrinsicOpConversion<forth::ThreadIdXOp>>(typeConverter,
Expand Down
11 changes: 11 additions & 0 deletions lib/Translation/ForthToMLIR/ForthToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,17 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack,
.getResult();
} else if (word == "ROT") {
return builder.create<forth::RotOp>(loc, stackType, inputStack).getResult();
} else if (word == "NIP") {
return builder.create<forth::NipOp>(loc, stackType, inputStack).getResult();
} else if (word == "TUCK") {
return builder.create<forth::TuckOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "PICK") {
return builder.create<forth::PickOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "ROLL") {
return builder.create<forth::RollOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "+" || word == "ADD") {
return builder.create<forth::AddOp>(loc, stackType, inputStack).getResult();
} else if (word == "-" || word == "SUB") {
Expand Down
40 changes: 40 additions & 0 deletions test/Conversion/ForthToMemRef/stack-manipulation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,40 @@
// CHECK: memref.store %[[ROT_C]], %{{.*}}[%[[ROT_SP1]]] : memref<256xi64>
// CHECK: memref.store %[[ROT_A]], %{{.*}}[%[[ROT_SP]]] : memref<256xi64>

// nip: load top, subi SP, store at SP-1 (a b -- b)
// CHECK: %[[NIP_B:.*]] = memref.load %{{.*}}[%{{.*}}] : memref<256xi64>
// CHECK: %[[NIP_SP1:.*]] = arith.subi
// CHECK: memref.store %[[NIP_B]], %{{.*}}[%[[NIP_SP1]]] : memref<256xi64>

// tuck: load b and a, store b/a/b (a b -- b a b)
// CHECK: %[[TUCK_B:.*]] = memref.load %{{.*}}[%[[TUCK_SP:.*]]] : memref<256xi64>
// CHECK: %[[TUCK_SP1:.*]] = arith.subi
// CHECK: %[[TUCK_A:.*]] = memref.load %{{.*}}[%[[TUCK_SP1]]] : memref<256xi64>
// CHECK: memref.store %[[TUCK_B]], %{{.*}}[%[[TUCK_SP1]]] : memref<256xi64>
// CHECK: memref.store %[[TUCK_A]], %{{.*}}[%[[TUCK_SP]]] : memref<256xi64>
// CHECK: %[[TUCK_NSP:.*]] = arith.addi
// CHECK: memref.store %[[TUCK_B]], %{{.*}}[%[[TUCK_NSP]]] : memref<256xi64>

// pick: load n, index_cast, subi (dynamic), load picked, store
// CHECK: %[[PICK_N:.*]] = memref.load %{{.*}}[%{{.*}}] : memref<256xi64>
// CHECK: %[[PICK_SP1:.*]] = arith.subi
// CHECK: %[[PICK_NIDX:.*]] = arith.index_cast %[[PICK_N]]
// CHECK: %[[PICK_ADDR:.*]] = arith.subi %[[PICK_SP1]], %[[PICK_NIDX]]
// CHECK: %[[PICK_VAL:.*]] = memref.load %{{.*}}[%[[PICK_ADDR]]] : memref<256xi64>
// CHECK: memref.store %[[PICK_VAL]]

// roll: load n, index_cast, subi (dynamic), load saved, scf.for with load/store, store saved
// CHECK: %[[ROLL_N:.*]] = memref.load %{{.*}}[%{{.*}}] : memref<256xi64>
// CHECK: %[[ROLL_SP1:.*]] = arith.subi
// CHECK: %[[ROLL_NIDX:.*]] = arith.index_cast %[[ROLL_N]]
// CHECK: %[[ROLL_ADDR:.*]] = arith.subi %[[ROLL_SP1]], %[[ROLL_NIDX]]
// CHECK: %[[ROLL_SAVED:.*]] = memref.load %{{.*}}[%[[ROLL_ADDR]]] : memref<256xi64>
// CHECK: scf.for %[[ROLL_IV:.*]] = %[[ROLL_ADDR]] to %[[ROLL_SP1]]
// CHECK: %[[ROLL_NEXT:.*]] = arith.addi %[[ROLL_IV]]
// CHECK: %[[ROLL_SHIFTED:.*]] = memref.load %{{.*}}[%[[ROLL_NEXT]]] : memref<256xi64>
// CHECK: memref.store %[[ROLL_SHIFTED]], %{{.*}}[%[[ROLL_IV]]] : memref<256xi64>
// CHECK: memref.store %[[ROLL_SAVED]], %{{.*}}[%[[ROLL_SP1]]] : memref<256xi64>

module {
func.func private @main() {
%0 = forth.stack !forth.stack
Expand All @@ -44,6 +78,12 @@ module {
%6 = forth.swap %5 : !forth.stack -> !forth.stack
%7 = forth.over %6 : !forth.stack -> !forth.stack
%8 = forth.rot %7 : !forth.stack -> !forth.stack
%9 = forth.nip %8 : !forth.stack -> !forth.stack
%10 = forth.tuck %9 : !forth.stack -> !forth.stack
%11 = forth.literal %10 2 : !forth.stack -> !forth.stack
%12 = forth.pick %11 : !forth.stack -> !forth.stack
%13 = forth.literal %12 2 : !forth.stack -> !forth.stack
%14 = forth.roll %13 : !forth.stack -> !forth.stack
return
}
}
6 changes: 5 additions & 1 deletion test/Translation/Forth/stack-ops.forth
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@
\ CHECK: forth.swap %{{.*}} : !forth.stack -> !forth.stack
\ CHECK: forth.over %{{.*}} : !forth.stack -> !forth.stack
\ CHECK: forth.rot %{{.*}} : !forth.stack -> !forth.stack
1 DUP DROP SWAP OVER ROT
\ CHECK: forth.nip %{{.*}} : !forth.stack -> !forth.stack
\ CHECK: forth.tuck %{{.*}} : !forth.stack -> !forth.stack
\ CHECK: forth.pick %{{.*}} : !forth.stack -> !forth.stack
\ CHECK: forth.roll %{{.*}} : !forth.stack -> !forth.stack
1 DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL