diff --git a/CLAUDE.md b/CLAUDE.md index dae0bd4..6703e93 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -87,7 +87,7 @@ uv run ruff format gpu_test/ - **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety - **Operations**: All take stack as input and produce stack as output (except `forth.stack`) -- **Supported Words**: literals, `DUP DROP SWAP OVER ROT`, `+ - * / MOD`, `= < > 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). +- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `DO LOOP I`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing). - **Kernel Parameters**: Declared with `PARAM `, each becomes a `memref` function argument with `forth.param_name` attribute. Using a param name in code pushes its byte address onto the stack via `forth.param_ref` - **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer - **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td index 11bac48..f4b8b4a 100644 --- a/include/warpforth/Dialect/Forth/ForthOps.td +++ b/include/warpforth/Dialect/Forth/ForthOps.td @@ -103,6 +103,67 @@ def Forth_RotOp : Forth_Op<"rot", [Pure]> { }]; } +def Forth_NipOp : Forth_Op<"nip", [Pure]> { + let summary = "Remove second stack element"; + let description = [{ + Removes the second element from the stack, keeping the top. + Forth semantics: ( a b -- b ) + }]; + + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + }]; +} + +def Forth_TuckOp : Forth_Op<"tuck", [Pure]> { + let summary = "Copy top element before second"; + let description = [{ + Copies the top element and inserts it before the second element. + Forth semantics: ( a b -- b a b ) + }]; + + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + }]; +} + +def Forth_PickOp : Forth_Op<"pick", [Pure]> { + let summary = "Copy nth element to top"; + let description = [{ + Pops n from the stack, then copies the nth element to the top. + Forth semantics: ( xn ... x0 n -- xn ... x0 xn ) + }]; + + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + }]; +} + +def Forth_RollOp : Forth_Op<"roll", [Pure]> { + let summary = "Rotate nth element to top"; + let description = [{ + Pops n from the stack, then rotates the nth element to the top, + shifting elements above it down. + Forth semantics: ( xn ... x0 n -- xn-1 ... x0 xn ) + }]; + + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack); + + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + }]; +} + //===----------------------------------------------------------------------===// // Literal operations. //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp index d2ca260..6539ecc 100644 --- a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp +++ b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp @@ -255,6 +255,160 @@ struct RotOpConversion : public OpConversionPattern { } }; +/// Conversion pattern for forth.nip operation. +/// Removes the second element: (a b -- b) +struct NipOpConversion : public OpConversionPattern { + NipOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(forth::NipOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + // Load top value (b at SP) + Value b = rewriter.create(loc, memref, stackPtr); + + // Store b at SP-1 (overwriting a) + Value one = rewriter.create(loc, 1); + Value spMinus1 = rewriter.create(loc, stackPtr, one); + rewriter.create(loc, b, memref, spMinus1); + + // Net effect: SP-1 (removed one element) + rewriter.replaceOpWithMultiple(op, {{memref, spMinus1}}); + return success(); + } +}; + +/// Conversion pattern for forth.tuck operation. +/// Copies top before second: (a b -- b a b) +struct TuckOpConversion : public OpConversionPattern { + TuckOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(forth::TuckOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + Value one = rewriter.create(loc, 1); + + // Load top two values + Value b = rewriter.create(loc, memref, stackPtr); + Value spMinus1 = rewriter.create(loc, stackPtr, one); + Value a = rewriter.create(loc, memref, spMinus1); + + // Store: b at SP-1, a at SP, b at SP+1 + rewriter.create(loc, b, memref, spMinus1); + rewriter.create(loc, a, memref, stackPtr); + Value newSP = rewriter.create(loc, stackPtr, one); + rewriter.create(loc, b, memref, newSP); + + // Net effect: SP+1 (added one element) + rewriter.replaceOpWithMultiple(op, {{memref, newSP}}); + return success(); + } +}; + +/// Conversion pattern for forth.pick operation. +/// Copies nth element to top: ( xn ... x0 n -- xn ... x0 xn ) +struct PickOpConversion : public OpConversionPattern { + PickOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(forth::PickOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + // Pop n from stack + Value nI64 = rewriter.create(loc, memref, stackPtr); + Value one = rewriter.create(loc, 1); + Value spAfterPop = rewriter.create(loc, stackPtr, one); + + // Cast n to index + Value nIdx = + rewriter.create(loc, rewriter.getIndexType(), nI64); + + // Compute target address: SP' - n + Value targetAddr = rewriter.create(loc, spAfterPop, nIdx); + + // Load the picked value + Value pickedValue = + rewriter.create(loc, memref, targetAddr); + + // Store at SP (where n was), effectively pushing the picked value + rewriter.create(loc, pickedValue, memref, stackPtr); + + // Net effect: SP unchanged (popped n, pushed xn) + rewriter.replaceOpWithMultiple(op, {{memref, stackPtr}}); + return success(); + } +}; + +/// Conversion pattern for forth.roll operation. +/// Rotates nth element to top: ( xn ... x0 n -- xn-1 ... x0 xn ) +struct RollOpConversion : public OpConversionPattern { + RollOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(forth::RollOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ValueRange inputStack = adaptor.getOperands()[0]; + Value memref = inputStack[0]; + Value stackPtr = inputStack[1]; + + // Pop n from stack + Value nI64 = rewriter.create(loc, memref, stackPtr); + Value one = rewriter.create(loc, 1); + Value spAfterPop = rewriter.create(loc, stackPtr, one); + + // Cast n to index + Value nIdx = + rewriter.create(loc, rewriter.getIndexType(), nI64); + + // Compute address of the element to roll: SP' - n + Value rolledAddr = rewriter.create(loc, spAfterPop, nIdx); + + // Save the value to be rolled to top + Value rolledValue = + rewriter.create(loc, memref, rolledAddr); + + // Shift elements down: for i in [rolledAddr, SP') : memref[i] = memref[i+1] + auto forOp = rewriter.create(loc, rolledAddr, spAfterPop, one); + + // Insert ops at start of the auto-created body, before the yield + rewriter.setInsertionPointToStart(forOp.getBody()); + Value iv = forOp.getInductionVar(); + Value iPlusOne = rewriter.create(loc, iv, one); + Value shiftedVal = rewriter.create(loc, memref, iPlusOne); + rewriter.create(loc, shiftedVal, memref, iv); + + // Store saved value at top (SP') + rewriter.setInsertionPointAfter(forOp); + rewriter.create(loc, rolledValue, memref, spAfterPop); + + // Net effect: SP' = SP - 1 (consumed n) + rewriter.replaceOpWithMultiple(op, {{memref, spAfterPop}}); + return success(); + } +}; + /// Base template for binary arithmetic operations. /// Pops two values, applies operation, pushes result: (a b -- result) template @@ -929,17 +1083,17 @@ struct ConvertForthToMemRefPass RewritePatternSet patterns(context); // Add Forth operation conversion patterns - patterns - .add( - typeConverter, context); + patterns.add< + StackOpConversion, LiteralOpConversion, DupOpConversion, + DropOpConversion, SwapOpConversion, OverOpConversion, RotOpConversion, + NipOpConversion, TuckOpConversion, PickOpConversion, RollOpConversion, + AddOpConversion, SubOpConversion, MulOpConversion, DivOpConversion, + ModOpConversion, AndOpConversion, OrOpConversion, XorOpConversion, + NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion, + LtOpConversion, GtOpConversion, ZeroEqOpConversion, + ParamRefOpConversion, LoadOpConversion, StoreOpConversion, + IfOpConversion, BeginUntilOpConversion, DoLoopOpConversion, + LoopIndexOpConversion, YieldOpConversion>(typeConverter, context); // Add GPU indexing op conversion patterns patterns.add>(typeConverter, diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp index 8a56c74..ef012cf 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp @@ -261,6 +261,17 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, .getResult(); } else if (word == "ROT") { return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "NIP") { + return builder.create(loc, stackType, inputStack).getResult(); + } else if (word == "TUCK") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "PICK") { + return builder.create(loc, stackType, inputStack) + .getResult(); + } else if (word == "ROLL") { + return builder.create(loc, stackType, inputStack) + .getResult(); } else if (word == "+" || word == "ADD") { return builder.create(loc, stackType, inputStack).getResult(); } else if (word == "-" || word == "SUB") { diff --git a/test/Conversion/ForthToMemRef/stack-manipulation.mlir b/test/Conversion/ForthToMemRef/stack-manipulation.mlir index f2d1ada..bd69acd 100644 --- a/test/Conversion/ForthToMemRef/stack-manipulation.mlir +++ b/test/Conversion/ForthToMemRef/stack-manipulation.mlir @@ -33,6 +33,40 @@ // CHECK: memref.store %[[ROT_C]], %{{.*}}[%[[ROT_SP1]]] : memref<256xi64> // CHECK: memref.store %[[ROT_A]], %{{.*}}[%[[ROT_SP]]] : memref<256xi64> +// nip: load top, subi SP, store at SP-1 (a b -- b) +// CHECK: %[[NIP_B:.*]] = memref.load %{{.*}}[%{{.*}}] : memref<256xi64> +// CHECK: %[[NIP_SP1:.*]] = arith.subi +// CHECK: memref.store %[[NIP_B]], %{{.*}}[%[[NIP_SP1]]] : memref<256xi64> + +// tuck: load b and a, store b/a/b (a b -- b a b) +// CHECK: %[[TUCK_B:.*]] = memref.load %{{.*}}[%[[TUCK_SP:.*]]] : memref<256xi64> +// CHECK: %[[TUCK_SP1:.*]] = arith.subi +// CHECK: %[[TUCK_A:.*]] = memref.load %{{.*}}[%[[TUCK_SP1]]] : memref<256xi64> +// CHECK: memref.store %[[TUCK_B]], %{{.*}}[%[[TUCK_SP1]]] : memref<256xi64> +// CHECK: memref.store %[[TUCK_A]], %{{.*}}[%[[TUCK_SP]]] : memref<256xi64> +// CHECK: %[[TUCK_NSP:.*]] = arith.addi +// CHECK: memref.store %[[TUCK_B]], %{{.*}}[%[[TUCK_NSP]]] : memref<256xi64> + +// pick: load n, index_cast, subi (dynamic), load picked, store +// CHECK: %[[PICK_N:.*]] = memref.load %{{.*}}[%{{.*}}] : memref<256xi64> +// CHECK: %[[PICK_SP1:.*]] = arith.subi +// CHECK: %[[PICK_NIDX:.*]] = arith.index_cast %[[PICK_N]] +// CHECK: %[[PICK_ADDR:.*]] = arith.subi %[[PICK_SP1]], %[[PICK_NIDX]] +// CHECK: %[[PICK_VAL:.*]] = memref.load %{{.*}}[%[[PICK_ADDR]]] : memref<256xi64> +// CHECK: memref.store %[[PICK_VAL]] + +// roll: load n, index_cast, subi (dynamic), load saved, scf.for with load/store, store saved +// CHECK: %[[ROLL_N:.*]] = memref.load %{{.*}}[%{{.*}}] : memref<256xi64> +// CHECK: %[[ROLL_SP1:.*]] = arith.subi +// CHECK: %[[ROLL_NIDX:.*]] = arith.index_cast %[[ROLL_N]] +// CHECK: %[[ROLL_ADDR:.*]] = arith.subi %[[ROLL_SP1]], %[[ROLL_NIDX]] +// CHECK: %[[ROLL_SAVED:.*]] = memref.load %{{.*}}[%[[ROLL_ADDR]]] : memref<256xi64> +// CHECK: scf.for %[[ROLL_IV:.*]] = %[[ROLL_ADDR]] to %[[ROLL_SP1]] +// CHECK: %[[ROLL_NEXT:.*]] = arith.addi %[[ROLL_IV]] +// CHECK: %[[ROLL_SHIFTED:.*]] = memref.load %{{.*}}[%[[ROLL_NEXT]]] : memref<256xi64> +// CHECK: memref.store %[[ROLL_SHIFTED]], %{{.*}}[%[[ROLL_IV]]] : memref<256xi64> +// CHECK: memref.store %[[ROLL_SAVED]], %{{.*}}[%[[ROLL_SP1]]] : memref<256xi64> + module { func.func private @main() { %0 = forth.stack !forth.stack @@ -44,6 +78,12 @@ module { %6 = forth.swap %5 : !forth.stack -> !forth.stack %7 = forth.over %6 : !forth.stack -> !forth.stack %8 = forth.rot %7 : !forth.stack -> !forth.stack + %9 = forth.nip %8 : !forth.stack -> !forth.stack + %10 = forth.tuck %9 : !forth.stack -> !forth.stack + %11 = forth.literal %10 2 : !forth.stack -> !forth.stack + %12 = forth.pick %11 : !forth.stack -> !forth.stack + %13 = forth.literal %12 2 : !forth.stack -> !forth.stack + %14 = forth.roll %13 : !forth.stack -> !forth.stack return } } diff --git a/test/Translation/Forth/stack-ops.forth b/test/Translation/Forth/stack-ops.forth index 9bfa03f..01cccf6 100644 --- a/test/Translation/Forth/stack-ops.forth +++ b/test/Translation/Forth/stack-ops.forth @@ -6,4 +6,8 @@ \ CHECK: forth.swap %{{.*}} : !forth.stack -> !forth.stack \ CHECK: forth.over %{{.*}} : !forth.stack -> !forth.stack \ CHECK: forth.rot %{{.*}} : !forth.stack -> !forth.stack -1 DUP DROP SWAP OVER ROT +\ CHECK: forth.nip %{{.*}} : !forth.stack -> !forth.stack +\ CHECK: forth.tuck %{{.*}} : !forth.stack -> !forth.stack +\ CHECK: forth.pick %{{.*}} : !forth.stack -> !forth.stack +\ CHECK: forth.roll %{{.*}} : !forth.stack -> !forth.stack +1 DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL