diff --git a/src/autocomplete/content-assist.ts b/src/autocomplete/content-assist.ts index ecc3512..2026003 100644 --- a/src/autocomplete/content-assist.ts +++ b/src/autocomplete/content-assist.ts @@ -76,6 +76,10 @@ export interface ContentAssistResult { * should resolve this against tablesInScope aliases/names to filter columns. */ qualifiedTableRef?: string + /** Whether the grammar context expects column names (expression/columnRef positions) */ + suggestColumns: boolean + /** Whether the grammar context expects table names (tableName positions, or expression context) */ + suggestTables: boolean } // ============================================================================= @@ -239,6 +243,23 @@ function extractTablesFromAst(ast: unknown): ExtractResult { } } + // Handle ALTER TABLE / ALTER MATERIALIZED VIEW + if ( + (n.type === "alterTable" || n.type === "alterMaterializedView") && + n.table + ) { + const tableName = normalizeTableName(n.table) + if (tableName) { + tables.push({ table: tableName }) + } + } + if (n.type === "alterMaterializedView" && n.view) { + const viewName = normalizeTableName(n.view) + if (viewName) { + tables.push({ table: viewName }) + } + } + // Recurse into child nodes for (const key of Object.keys(n)) { const child = n[key] @@ -511,28 +532,46 @@ function extractTables(fullSql: string, tokens: IToken[]): ExtractResult { // Scan for FROM/JOIN table references only in the outer query (after CTEs). // This avoids leaking tables referenced inside CTE bodies into the outer scope. + // Also detect ALTER TABLE / TRUNCATE TABLE patterns for column scoping. + const DDL_TABLE_PREFIXES = new Set(["Alter", "Truncate", "Drop"]) for (let i = outerQueryStart; i < tokens.length; i++) { - if (!TABLE_PREFIX_TOKENS.has(tokens[i].tokenType.name)) continue + const tokenName = tokens[i].tokenType.name - const tableNameResult = readQualifiedName(i + 1) - if (!tableNameResult) continue + // Standard DML: FROM/JOIN/UPDATE/INTO + if (TABLE_PREFIX_TOKENS.has(tokenName)) { + const tableNameResult = readQualifiedName(i + 1) + if (!tableNameResult) continue - let alias: string | undefined - let aliasStart = tableNameResult.nextIndex - if (tokens[aliasStart]?.tokenType.name === "As") { - aliasStart++ - } - if (isIdentifierLike(tokens[aliasStart])) { - alias = tokenToNamePart(tokens[aliasStart]) - } + let alias: string | undefined + let aliasStart = tableNameResult.nextIndex + if (tokens[aliasStart]?.tokenType.name === "As") { + aliasStart++ + } + if (isIdentifierLike(tokens[aliasStart])) { + alias = tokenToNamePart(tokens[aliasStart]) + } - tables.push({ - table: tableNameResult.name, - alias, - }) + tables.push({ + table: tableNameResult.name, + alias, + }) + + // Continue from where we consumed table/alias to avoid duplicate captures. + i = alias ? aliasStart : tableNameResult.nextIndex - 1 + continue + } - // Continue from where we consumed table/alias to avoid duplicate captures. - i = alias ? aliasStart : tableNameResult.nextIndex - 1 + // DDL: ALTER TABLE / TRUNCATE TABLE / DROP TABLE + if ( + DDL_TABLE_PREFIXES.has(tokenName) && + tokens[i + 1]?.tokenType.name === "Table" + ) { + const tableNameResult = readQualifiedName(i + 2) + if (tableNameResult) { + tables.push({ table: tableNameResult.name }) + i = tableNameResult.nextIndex - 1 + } + } } for (const name of cteNames) { tables.push({ table: name }) @@ -615,6 +654,32 @@ function collapseTrailingQualifiedRef(tokens: IToken[]): IToken[] | null { return [...tokens.slice(0, start), lastToken] } +/** + * Classify an identifier suggestion path based on its ruleStack. + * - "column": identifierExpression or columnRef → suggest columns + tables + * - "table": tableName rule → suggest tables only + * - "newName": everything else (CREATE TABLE name, user names, etc.) → no suggestions + */ +function classifyIdentifierPath( + ruleStack: string[], +): "column" | "table" | "newName" { + if (ruleStack.includes("valuesClause")) return "newName" + if ( + ruleStack.includes("identifierExpression") || + ruleStack.includes("columnRef") || + ruleStack.includes("qualifiedStar") + ) + return "column" + if (ruleStack.includes("tableName")) return "table" + return "newName" +} + +interface ComputeResult { + nextTokenTypes: TokenType[] + suggestColumns: boolean + suggestTables: boolean +} + /** * Compute content assist suggestions, handling CTE context specially. * @@ -623,7 +688,7 @@ function collapseTrailingQualifiedRef(tokens: IToken[]): IToken[] | null { * updateStatement paths. This function detects that case and merges suggestions * from all WITH-capable statement types. */ -function computeSuggestions(tokens: IToken[]): TokenType[] { +function computeSuggestions(tokens: IToken[]): ComputeResult { const ruleName = tokens.some((t) => t.tokenType.name === "Semicolon") ? "statements" : "statement" @@ -638,22 +703,34 @@ function computeSuggestions(tokens: IToken[]): TokenType[] { const specific = suggestions.filter( (s) => !isImplicitStatementPath(s.ruleStack, IMPLICIT_RULES), ) - const result = (specific.length > 0 ? specific : suggestions).map( - (s) => s.nextTokenType, - ) + const effectiveSuggestions = specific.length > 0 ? specific : suggestions + const result = effectiveSuggestions.map((s) => s.nextTokenType) + + // Classify each IdentifierKeyword path to determine whether columns/tables + // should be suggested, based on the grammar rule that expects the identifier. + let suggestColumns = false + let suggestTables = false + for (const s of effectiveSuggestions) { + if (s.nextTokenType.name === "IdentifierKeyword") { + const cls = classifyIdentifierPath(s.ruleStack) + if (cls === "column") { + suggestColumns = true + suggestTables = true + } else if (cls === "table") { + suggestTables = true + } + } + } // qualifiedStar fix: When computeContentAssist finds the qualifiedStar // path in selectItem (suggesting just Dot), the expression path is missed. // Detect this by checking if the *specific* (non-catch-all) suggestions are // all from qualifiedStar, then re-compute with the qualified reference // collapsed to a single identifier to get expression-path suggestions. - const effectiveSuggestions = specific.length > 0 ? specific : suggestions if ( effectiveSuggestions.length > 0 && effectiveSuggestions.every((s) => s.ruleStack.includes("qualifiedStar")) ) { - // Find and collapse the trailing qualified reference (ident.ident...ident) - // into a single identifier token, then re-compute to get expression-path suggestions. const collapsed = collapseTrailingQualifiedRef(tokens) if (collapsed) { try { @@ -661,14 +738,23 @@ function computeSuggestions(tokens: IToken[]): TokenType[] { const filteredExtra = extra.filter( (s) => !isImplicitStatementPath(s.ruleStack, IMPLICIT_RULES), ) - const extraResult = ( + const extraEffective = filteredExtra.length > 0 ? filteredExtra : extra - ).map((s) => s.nextTokenType) const seen = new Set(result.map((t) => t.name)) - for (const t of extraResult) { - if (!seen.has(t.name)) { - seen.add(t.name) - result.push(t) + for (const s of extraEffective) { + if (!seen.has(s.nextTokenType.name)) { + seen.add(s.nextTokenType.name) + result.push(s.nextTokenType) + } + // Classify extra paths too + if (s.nextTokenType.name === "IdentifierKeyword") { + const cls = classifyIdentifierPath(s.ruleStack) + if (cls === "column") { + suggestColumns = true + suggestTables = true + } else if (cls === "table") { + suggestTables = true + } } } } catch (e) { @@ -677,7 +763,7 @@ function computeSuggestions(tokens: IToken[]): TokenType[] { } } - return result + return { nextTokenTypes: result, suggestColumns, suggestTables } } /** @@ -738,6 +824,8 @@ export function getContentAssist( tokensBefore: [], isMidWord: true, lexErrors: [], + suggestColumns: false, + suggestTables: false, } } } @@ -764,8 +852,13 @@ export function getContentAssist( // Get syntactically valid next tokens using Chevrotain's content assist let nextTokenTypes: TokenType[] = [] + let suggestColumns = false + let suggestTables = false try { - nextTokenTypes = computeSuggestions(tokensForAssist) + const computed = computeSuggestions(tokensForAssist) + nextTokenTypes = computed.nextTokenTypes + suggestColumns = computed.suggestColumns + suggestTables = computed.suggestTables } catch (e) { // If content assist fails, return empty suggestions // This can happen with malformed input @@ -834,6 +927,8 @@ export function getContentAssist( isMidWord, lexErrors: lexResult.errors, qualifiedTableRef: qualifiedRef?.table, + suggestColumns, + suggestTables, } } @@ -843,7 +938,9 @@ export function getContentAssist( export function getNextValidTokens(sql: string): string[] { const lexResult = QuestDBLexer.tokenize(sql) try { - return computeSuggestions(lexResult.tokens).map((t) => t.name) + return computeSuggestions(lexResult.tokens).nextTokenTypes.map( + (t) => t.name, + ) } catch (e) { return [] } diff --git a/src/autocomplete/provider.ts b/src/autocomplete/provider.ts index 71811b3..7e0e822 100644 --- a/src/autocomplete/provider.ts +++ b/src/autocomplete/provider.ts @@ -21,10 +21,7 @@ import type { IToken } from "chevrotain" import { getContentAssist } from "./content-assist" import { buildSuggestions } from "./suggestion-builder" -import { - shouldSkipToken, - IDENTIFIER_KEYWORD_TOKENS, -} from "./token-classification" +import { shouldSkipToken } from "./token-classification" import type { AutocompleteProvider, SchemaInfo, Suggestion } from "./types" import { SuggestionKind, SuggestionPriority } from "./types" @@ -54,81 +51,6 @@ function getLastSignificantTokens(tokens: IToken[]): string[] { } return result } -/** - * Tokens that signal the end of an expression / value. When these appear as - * the raw last token before the cursor, the cursor is in alias or keyword - * position — NOT column position. e.g., "SELECT symbol |" → alias position. - */ -const EXPRESSION_END_TOKENS = new Set([ - "Identifier", - "QuotedIdentifier", - "RParen", - "NumberLiteral", - "LongLiteral", - "DecimalLiteral", - "StringLiteral", -]) - -function isExpressionEnd(tokenName: string): boolean { - return ( - EXPRESSION_END_TOKENS.has(tokenName) || - IDENTIFIER_KEYWORD_TOKENS.has(tokenName) - ) -} - -function getIdentifierSuggestionScope( - lastTokenName?: string, - prevTokenName?: string, - rawLastTokenName?: string, - rawPrevTokenName?: string, -): { - includeColumns: boolean - includeTables: boolean -} { - // Expression-end tokens indicate alias / post-expression position. - // e.g., "SELECT symbol |" or "FROM trades |" — no columns expected. - if (rawLastTokenName && isExpressionEnd(rawLastTokenName)) { - return { includeColumns: false, includeTables: false } - } - - // Star (*) is context-dependent: it's a wildcard after SELECT/comma/LParen, - // but multiplication after an expression (identifier, number, rparen). - // "SELECT * |" → wildcard, suppress columns (alias/keyword position) - // "SELECT price * |" → multiplication, suggest columns for RHS - if (rawLastTokenName === "Star") { - if (rawPrevTokenName && isExpressionEnd(rawPrevTokenName)) { - // Multiplication: previous token is an expression-end, so * is an operator. - // The user needs columns/functions for the right-hand side. - return { includeColumns: true, includeTables: true } - } - // Wildcard: no expression before *, e.g., SELECT *, t.*, or start of expression - return { includeColumns: false, includeTables: false } - } - - // After AS keyword: either subquery start (WITH name AS (|) or alias (SELECT x AS |). - if (lastTokenName === "As") { - // "WITH name AS (|" → LParen is raw last → subquery start, suggest tables - if (rawLastTokenName === "LParen") { - return { includeColumns: false, includeTables: true } - } - // "SELECT x AS |" → alias position - return { includeColumns: false, includeTables: false } - } - - if (prevTokenName && TABLE_NAME_TOKENS.has(prevTokenName)) { - return { includeColumns: false, includeTables: true } - } - if (lastTokenName && TABLE_NAME_TOKENS.has(lastTokenName)) { - return { includeColumns: false, includeTables: true } - } - // At statement start (no significant tokens before cursor), only suggest - // tables and keywords, not columns. Identifier is valid here only because - // of PIVOT syntax (e.g., "trades PIVOT (...)"), not for column references. - if (!lastTokenName) { - return { includeColumns: false, includeTables: true } - } - return { includeColumns: true, includeTables: true } -} /** * Create an autocomplete provider with the given schema @@ -172,6 +94,8 @@ export function createAutocompleteProvider( tokensBefore, isMidWord, qualifiedTableRef, + suggestColumns, + suggestTables, } = getContentAssist(query, cursorOffset) // Merge CTE columns into the schema so getColumnsInScope() can find them @@ -220,35 +144,17 @@ export function createAutocompleteProvider( } } - // If parser returned valid next tokens, use them + // If parser returned valid next tokens, use grammar-based classification if (nextTokenTypes.length > 0) { - // When mid-word, the last token in tokensBefore is the partial word being typed. - // For scope detection, we need the tokens BEFORE that partial word. - const tokensForScope = - isMidWord && tokensBefore.length > 0 - ? tokensBefore.slice(0, -1) - : tokensBefore - const [lastTokenName, prevTokenName] = - getLastSignificantTokens(tokensForScope) - const rawLastTokenName = - tokensForScope.length > 0 - ? tokensForScope[tokensForScope.length - 1]?.tokenType?.name - : undefined - const rawPrevTokenName = - tokensForScope.length > 1 - ? tokensForScope[tokensForScope.length - 2]?.tokenType?.name - : undefined - const scope = getIdentifierSuggestionScope( - lastTokenName, - prevTokenName, - rawLastTokenName, - rawPrevTokenName, - ) return buildSuggestions( nextTokenTypes, effectiveSchema, effectiveTablesInScope, - { ...scope, isMidWord }, + { + includeColumns: suggestColumns, + includeTables: suggestTables, + isMidWord, + }, ) } diff --git a/src/parser/cst-types.d.ts b/src/parser/cst-types.d.ts index b716cd4..1bca903 100644 --- a/src/parser/cst-types.d.ts +++ b/src/parser/cst-types.d.ts @@ -221,7 +221,7 @@ export type TableRefCstChildren = { tableFunctionCall?: TableFunctionCallCstNode[]; VariableReference?: IToken[]; StringLiteral?: IToken[]; - qualifiedName?: QualifiedNameCstNode[]; + tableName?: TableNameCstNode[]; Timestamp?: IToken[]; columnRef?: ColumnRefCstNode[]; As?: IToken[]; @@ -256,23 +256,53 @@ export interface JoinClauseCstNode extends CstNode { } export type JoinClauseCstChildren = { - Inner?: IToken[]; - Left?: IToken[]; - Right?: IToken[]; - Full?: IToken[]; - Cross?: IToken[]; + asofLtJoin?: AsofLtJoinCstNode[]; + spliceJoin?: SpliceJoinCstNode[]; + windowJoin?: WindowJoinCstNode[]; + standardJoin?: StandardJoinCstNode[]; +}; + +export interface AsofLtJoinCstNode extends CstNode { + name: "asofLtJoin"; + children: AsofLtJoinCstChildren; +} + +export type AsofLtJoinCstChildren = { Asof?: IToken[]; Lt?: IToken[]; - Splice?: IToken[]; - Window?: IToken[]; - Prevailing?: (IToken)[]; - Outer?: IToken[]; Join: IToken[]; tableRef: TableRefCstNode[]; On?: IToken[]; expression?: ExpressionCstNode[]; Tolerance?: IToken[]; DurationLiteral?: IToken[]; +}; + +export interface SpliceJoinCstNode extends CstNode { + name: "spliceJoin"; + children: SpliceJoinCstChildren; +} + +export type SpliceJoinCstChildren = { + Splice: IToken[]; + Join: IToken[]; + tableRef: TableRefCstNode[]; + On?: IToken[]; + expression?: ExpressionCstNode[]; +}; + +export interface WindowJoinCstNode extends CstNode { + name: "windowJoin"; + children: WindowJoinCstChildren; +} + +export type WindowJoinCstChildren = { + Window?: IToken[]; + Prevailing?: (IToken)[]; + Join: IToken[]; + tableRef: TableRefCstNode[]; + On?: IToken[]; + expression?: ExpressionCstNode[]; Range?: IToken[]; Between?: IToken[]; windowJoinBound?: (WindowJoinBoundCstNode)[]; @@ -281,6 +311,24 @@ export type JoinClauseCstChildren = { Exclude?: IToken[]; }; +export interface StandardJoinCstNode extends CstNode { + name: "standardJoin"; + children: StandardJoinCstChildren; +} + +export type StandardJoinCstChildren = { + Left?: IToken[]; + Right?: IToken[]; + Full?: IToken[]; + Outer?: IToken[]; + Inner?: IToken[]; + Cross?: IToken[]; + Join: IToken[]; + tableRef: TableRefCstNode[]; + On?: IToken[]; + expression?: ExpressionCstNode[]; +}; + export interface WindowJoinBoundCstNode extends CstNode { name: "windowJoinBound"; children: WindowJoinBoundCstChildren; @@ -455,7 +503,7 @@ export type InsertStatementCstChildren = { Atomic?: IToken[]; batchClause?: (BatchClauseCstNode)[]; Into: IToken[]; - stringOrQualifiedName: StringOrQualifiedNameCstNode[]; + tableNameOrString: TableNameOrStringCstNode[]; LParen?: IToken[]; identifier?: (IdentifierCstNode)[]; Comma?: IToken[]; @@ -494,7 +542,7 @@ export interface UpdateStatementCstNode extends CstNode { export type UpdateStatementCstChildren = { Update: IToken[]; - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; identifier?: IdentifierCstNode[]; Set: IToken[]; setClause: (SetClauseCstNode)[]; @@ -600,7 +648,7 @@ export type CreateTableBodyCstChildren = { indexDefinition?: (IndexDefinitionCstNode)[]; columnDefinition?: (ColumnDefinitionCstNode)[]; Like?: IToken[]; - qualifiedName?: QualifiedNameCstNode[]; + tableName?: TableNameCstNode[]; Timestamp?: IToken[]; columnRef?: ColumnRefCstNode[]; Partition?: IToken[]; @@ -1012,8 +1060,7 @@ export interface AlterTableStatementCstNode extends CstNode { export type AlterTableStatementCstChildren = { Table: IToken[]; - qualifiedName?: QualifiedNameCstNode[]; - StringLiteral?: IToken[]; + tableNameOrString: TableNameOrStringCstNode[]; alterTableAction: AlterTableActionCstNode[]; }; @@ -1031,7 +1078,7 @@ export type AlterTableActionCstChildren = { columnDefinition?: (ColumnDefinitionCstNode)[]; Comma?: (IToken)[]; Drop?: (IToken)[]; - identifier?: (IdentifierCstNode)[]; + columnRef?: (ColumnRefCstNode)[]; Partition?: (IToken)[]; List?: (IToken)[]; StringLiteral?: (IToken)[]; @@ -1039,6 +1086,7 @@ export type AlterTableActionCstChildren = { expression?: (ExpressionCstNode)[]; Rename?: IToken[]; To?: IToken[]; + identifier?: (IdentifierCstNode)[]; Alter?: IToken[]; Type?: (IToken)[]; dataType?: DataTypeCstNode[]; @@ -1101,7 +1149,7 @@ export interface AlterMaterializedViewStatementCstNode extends CstNode { export type AlterMaterializedViewStatementCstChildren = { Materialized: IToken[]; View: IToken[]; - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; alterMaterializedViewAction: AlterMaterializedViewActionCstNode[]; }; @@ -1113,7 +1161,7 @@ export interface AlterMaterializedViewActionCstNode extends CstNode { export type AlterMaterializedViewActionCstChildren = { Alter?: IToken[]; Column?: IToken[]; - identifier?: IdentifierCstNode[]; + columnRef?: ColumnRefCstNode[]; Add?: IToken[]; Index?: (IToken)[]; Capacity?: (IToken)[]; @@ -1174,7 +1222,7 @@ export type DropTableStatementCstChildren = { Table?: IToken[]; If?: IToken[]; Exists?: IToken[]; - qualifiedName?: QualifiedNameCstNode[]; + tableName?: TableNameCstNode[]; }; export interface DropMaterializedViewStatementCstNode extends CstNode { @@ -1238,7 +1286,7 @@ export type TruncateTableStatementCstChildren = { If?: IToken[]; Exists?: IToken[]; Only?: IToken[]; - qualifiedName: (QualifiedNameCstNode)[]; + tableName: (TableNameCstNode)[]; Comma?: IToken[]; Keep?: IToken[]; Symbol?: IToken[]; @@ -1253,7 +1301,7 @@ export interface RenameTableStatementCstNode extends CstNode { export type RenameTableStatementCstChildren = { Rename: IToken[]; Table: IToken[]; - stringOrQualifiedName: (StringOrQualifiedNameCstNode)[]; + tableNameOrString: (TableNameOrStringCstNode)[]; To: IToken[]; }; @@ -1403,7 +1451,7 @@ export interface CopyFromCstNode extends CstNode { } export type CopyFromCstChildren = { - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; From: IToken[]; stringOrIdentifier: StringOrIdentifierCstNode[]; copyOptions?: CopyOptionsCstNode[]; @@ -1418,7 +1466,7 @@ export type CopyToCstChildren = { LParen?: IToken[]; selectStatement?: SelectStatementCstNode[]; RParen?: IToken[]; - qualifiedName?: QualifiedNameCstNode[]; + tableName?: TableNameCstNode[]; To: IToken[]; stringOrIdentifier: StringOrIdentifierCstNode[]; copyOptions?: CopyOptionsCstNode[]; @@ -1507,7 +1555,7 @@ export type BackupStatementCstChildren = { Backup: IToken[]; Database?: IToken[]; Table?: IToken[]; - qualifiedName?: QualifiedNameCstNode[]; + tableName?: TableNameCstNode[]; Abort?: IToken[]; }; @@ -1519,7 +1567,7 @@ export interface CompileViewStatementCstNode extends CstNode { export type CompileViewStatementCstChildren = { Compile: IToken[]; View: IToken[]; - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; }; export interface GrantStatementCstNode extends CstNode { @@ -1604,7 +1652,7 @@ export interface GrantTableTargetCstNode extends CstNode { } export type GrantTableTargetCstChildren = { - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; LParen?: IToken[]; identifier?: (IdentifierCstNode)[]; Comma?: IToken[]; @@ -1649,7 +1697,7 @@ export interface VacuumTableStatementCstNode extends CstNode { export type VacuumTableStatementCstChildren = { Vacuum: IToken[]; Table: IToken[]; - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; }; export interface ResumeWalStatementCstNode extends CstNode { @@ -1687,7 +1735,7 @@ export interface ReindexTableStatementCstNode extends CstNode { export type ReindexTableStatementCstChildren = { Reindex: IToken[]; Table: IToken[]; - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; Column?: IToken[]; identifier?: (IdentifierCstNode)[]; Comma?: (IToken)[]; @@ -1706,7 +1754,7 @@ export type RefreshMaterializedViewStatementCstChildren = { Refresh: IToken[]; Materialized: IToken[]; View: IToken[]; - qualifiedName: QualifiedNameCstNode[]; + tableName: TableNameCstNode[]; Full?: IToken[]; Incremental?: IToken[]; Range?: IToken[]; @@ -1724,7 +1772,7 @@ export type PivotStatementCstChildren = { LParen: (IToken)[]; selectStatement?: SelectStatementCstNode[]; RParen: (IToken)[]; - qualifiedName?: QualifiedNameCstNode[]; + tableName?: TableNameCstNode[]; whereClause?: WhereClauseCstNode[]; Pivot: IToken[]; pivotBody: PivotBodyCstNode[]; @@ -2306,6 +2354,25 @@ export type ColumnRefCstChildren = { qualifiedName: QualifiedNameCstNode[]; }; +export interface TableNameCstNode extends CstNode { + name: "tableName"; + children: TableNameCstChildren; +} + +export type TableNameCstChildren = { + qualifiedName: QualifiedNameCstNode[]; +}; + +export interface TableNameOrStringCstNode extends CstNode { + name: "tableNameOrString"; + children: TableNameOrStringCstChildren; +} + +export type TableNameOrStringCstChildren = { + StringLiteral?: IToken[]; + tableName?: TableNameCstNode[]; +}; + export interface QualifiedNameCstNode extends CstNode { name: "qualifiedName"; children: QualifiedNameCstChildren; @@ -2347,6 +2414,10 @@ export interface ICstNodeVisitor extends ICstVisitor { tableFunctionCall(children: TableFunctionCallCstChildren, param?: IN): OUT; tableFunctionName(children: TableFunctionNameCstChildren, param?: IN): OUT; joinClause(children: JoinClauseCstChildren, param?: IN): OUT; + asofLtJoin(children: AsofLtJoinCstChildren, param?: IN): OUT; + spliceJoin(children: SpliceJoinCstChildren, param?: IN): OUT; + windowJoin(children: WindowJoinCstChildren, param?: IN): OUT; + standardJoin(children: StandardJoinCstChildren, param?: IN): OUT; windowJoinBound(children: WindowJoinBoundCstChildren, param?: IN): OUT; durationExpression(children: DurationExpressionCstChildren, param?: IN): OUT; whereClause(children: WhereClauseCstChildren, param?: IN): OUT; @@ -2478,6 +2549,8 @@ export interface ICstNodeVisitor extends ICstVisitor { intervalValue(children: IntervalValueCstChildren, param?: IN): OUT; timeZoneValue(children: TimeZoneValueCstChildren, param?: IN): OUT; columnRef(children: ColumnRefCstChildren, param?: IN): OUT; + tableName(children: TableNameCstChildren, param?: IN): OUT; + tableNameOrString(children: TableNameOrStringCstChildren, param?: IN): OUT; qualifiedName(children: QualifiedNameCstChildren, param?: IN): OUT; identifier(children: IdentifierCstChildren, param?: IN): OUT; } diff --git a/src/parser/parser.ts b/src/parser/parser.ts index a07706c..0544ab7 100644 --- a/src/parser/parser.ts +++ b/src/parser/parser.ts @@ -753,7 +753,7 @@ class QuestDBParser extends CstParser { ALT: () => this.CONSUME(VariableReference), }, { ALT: () => this.CONSUME(StringLiteral) }, - { ALT: () => this.SUBRULE(this.qualifiedName) }, + { ALT: () => this.SUBRULE(this.tableName) }, ]) // Optional TIMESTAMP designation on subquery/table results this.OPTION2(() => { @@ -811,44 +811,67 @@ class QuestDBParser extends CstParser { this.SUBRULE(this.identifier) }) + // ---- Join clause: dispatches to type-specific sub-rules so that each + // join type only offers its own valid postamble tokens. private joinClause = this.RULE("joinClause", () => { - this.OPTION(() => { - this.OR([ - { ALT: () => this.CONSUME(Inner) }, - { ALT: () => this.CONSUME(Left) }, - { ALT: () => this.CONSUME(Right) }, - { ALT: () => this.CONSUME(Full) }, - { ALT: () => this.CONSUME(Cross) }, - { ALT: () => this.CONSUME(Asof) }, - { ALT: () => this.CONSUME(Lt) }, - { ALT: () => this.CONSUME(Splice) }, - { ALT: () => this.CONSUME(Window) }, - { ALT: () => this.CONSUME(Prevailing) }, - ]) - this.OPTION1(() => this.CONSUME(Outer)) - }) + this.OR([ + { ALT: () => this.SUBRULE(this.asofLtJoin) }, + { ALT: () => this.SUBRULE(this.spliceJoin) }, + { ALT: () => this.SUBRULE(this.windowJoin) }, + { ALT: () => this.SUBRULE(this.standardJoin) }, + ]) + }) + + // ASOF/LT JOIN: ON + TOLERANCE + private asofLtJoin = this.RULE("asofLtJoin", () => { + this.OR([ + { ALT: () => this.CONSUME(Asof) }, + { ALT: () => this.CONSUME(Lt) }, + ]) this.CONSUME(Join) this.SUBRULE(this.tableRef) - this.OPTION2(() => { + this.OPTION(() => { this.CONSUME(On) this.SUBRULE(this.expression) }) - // TOLERANCE clause for ASOF and LT joins (QuestDB-specific) - this.OPTION3(() => { + this.OPTION1(() => { this.CONSUME(Tolerance) this.CONSUME(DurationLiteral) }) - // RANGE BETWEEN clause for WINDOW JOIN - this.OPTION4(() => { + }) + + // SPLICE JOIN: ON only + private spliceJoin = this.RULE("spliceJoin", () => { + this.CONSUME(Splice) + this.CONSUME(Join) + this.SUBRULE(this.tableRef) + this.OPTION(() => { + this.CONSUME(On) + this.SUBRULE(this.expression) + }) + }) + + // WINDOW/PREVAILING JOIN: ON + RANGE BETWEEN + INCLUDE/EXCLUDE PREVAILING + private windowJoin = this.RULE("windowJoin", () => { + this.OR([ + { ALT: () => this.CONSUME(Window) }, + { ALT: () => this.CONSUME(Prevailing) }, + ]) + this.CONSUME(Join) + this.SUBRULE(this.tableRef) + this.OPTION(() => { + this.CONSUME(On) + this.SUBRULE(this.expression) + }) + this.OPTION1(() => { this.CONSUME(Range) this.CONSUME(Between) this.SUBRULE(this.windowJoinBound) this.CONSUME(And) this.SUBRULE1(this.windowJoinBound) }) - // INCLUDE/EXCLUDE PREVAILING clause for WINDOW JOIN - this.OPTION5(() => { - this.OR3([ + this.OPTION2(() => { + this.OR1([ { ALT: () => this.CONSUME(Include) }, { ALT: () => this.CONSUME(Exclude) }, ]) @@ -856,6 +879,32 @@ class QuestDBParser extends CstParser { }) }) + // Standard joins: (INNER | LEFT [OUTER] | RIGHT [OUTER] | FULL [OUTER] | CROSS)? JOIN + ON + private standardJoin = this.RULE("standardJoin", () => { + this.OPTION(() => { + this.OR([ + { + ALT: () => { + this.OR1([ + { ALT: () => this.CONSUME(Left) }, + { ALT: () => this.CONSUME(Right) }, + { ALT: () => this.CONSUME(Full) }, + ]) + this.OPTION1(() => this.CONSUME(Outer)) + }, + }, + { ALT: () => this.CONSUME(Inner) }, + { ALT: () => this.CONSUME(Cross) }, + ]) + }) + this.CONSUME(Join) + this.SUBRULE(this.tableRef) + this.OPTION2(() => { + this.CONSUME(On) + this.SUBRULE(this.expression) + }) + }) + // Window join bound: PRECEDING/FOLLOWING | CURRENT ROW [PRECEDING/FOLLOWING] | DurationLiteral PRECEDING/FOLLOWING private windowJoinBound = this.RULE("windowJoinBound", () => { this.OR([ @@ -1090,7 +1139,7 @@ class QuestDBParser extends CstParser { ]) }) this.CONSUME(Into) - this.SUBRULE(this.stringOrQualifiedName) + this.SUBRULE(this.tableNameOrString) // Batch clause can also appear after table name this.OPTION2(() => this.SUBRULE1(this.batchClause)) this.OPTION3(() => { @@ -1133,7 +1182,7 @@ class QuestDBParser extends CstParser { private updateStatement = this.RULE("updateStatement", () => { this.CONSUME(Update) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) // Optional alias this.OPTION2(() => this.SUBRULE(this.identifier)) this.CONSUME(Set) @@ -1296,7 +1345,7 @@ class QuestDBParser extends CstParser { ALT: () => { this.CONSUME2(LParen) this.CONSUME(Like) - this.SUBRULE2(this.qualifiedName) + this.SUBRULE(this.tableName) this.CONSUME2(RParen) }, }, @@ -1911,10 +1960,7 @@ class QuestDBParser extends CstParser { private alterTableStatement = this.RULE("alterTableStatement", () => { this.CONSUME(Table) - this.OR([ - { ALT: () => this.SUBRULE(this.qualifiedName) }, - { ALT: () => this.CONSUME(StringLiteral) }, - ]) + this.SUBRULE(this.tableNameOrString) this.SUBRULE(this.alterTableAction) }) @@ -1946,10 +1992,10 @@ class QuestDBParser extends CstParser { { ALT: () => { this.CONSUME1(Column) - this.SUBRULE(this.identifier) + this.SUBRULE(this.columnRef) this.MANY1(() => { this.CONSUME1(Comma) - this.SUBRULE1(this.identifier) + this.SUBRULE1(this.columnRef) }) }, }, @@ -1985,7 +2031,7 @@ class QuestDBParser extends CstParser { ALT: () => { this.CONSUME(Rename) this.CONSUME2(Column) - this.SUBRULE2(this.identifier) + this.SUBRULE2(this.columnRef) this.CONSUME(To) this.SUBRULE3(this.identifier) }, @@ -1995,7 +2041,7 @@ class QuestDBParser extends CstParser { ALT: () => { this.CONSUME1(Alter) this.CONSUME3(Column) - this.SUBRULE4(this.identifier) + this.SUBRULE3(this.columnRef) this.OR9([ { ALT: () => { @@ -2220,7 +2266,7 @@ class QuestDBParser extends CstParser { () => { this.CONSUME(Materialized) this.CONSUME(View) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) this.SUBRULE(this.alterMaterializedViewAction) }, ) @@ -2233,7 +2279,7 @@ class QuestDBParser extends CstParser { ALT: () => { this.CONSUME(Alter) this.CONSUME(Column) - this.SUBRULE(this.identifier) + this.SUBRULE(this.columnRef) this.OR1([ { ALT: () => { @@ -2384,7 +2430,7 @@ class QuestDBParser extends CstParser { this.CONSUME(If) this.CONSUME(Exists) }) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) }, }, ]) @@ -2446,10 +2492,10 @@ class QuestDBParser extends CstParser { this.CONSUME(Exists) }) this.OPTION1(() => this.CONSUME(Only)) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) this.MANY(() => { this.CONSUME(Comma) - this.SUBRULE1(this.qualifiedName) + this.SUBRULE1(this.tableName) }) this.OPTION2(() => { this.CONSUME(Keep) @@ -2465,9 +2511,9 @@ class QuestDBParser extends CstParser { private renameTableStatement = this.RULE("renameTableStatement", () => { this.CONSUME(Rename) this.CONSUME(Table) - this.SUBRULE(this.stringOrQualifiedName) + this.SUBRULE(this.tableNameOrString) this.CONSUME(To) - this.SUBRULE1(this.stringOrQualifiedName) + this.SUBRULE1(this.tableNameOrString) }) private addUserStatement = this.RULE("addUserStatement", () => { @@ -2690,7 +2736,7 @@ class QuestDBParser extends CstParser { }) private copyFrom = this.RULE("copyFrom", () => { - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) this.CONSUME(From) this.SUBRULE(this.stringOrIdentifier) this.OPTION(() => this.SUBRULE(this.copyOptions)) @@ -2705,7 +2751,7 @@ class QuestDBParser extends CstParser { this.CONSUME(RParen) }, }, - { ALT: () => this.SUBRULE(this.qualifiedName) }, + { ALT: () => this.SUBRULE(this.tableName) }, ]) this.CONSUME(To) this.SUBRULE1(this.stringOrIdentifier) @@ -2861,7 +2907,7 @@ class QuestDBParser extends CstParser { { ALT: () => { this.CONSUME(Table) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) }, }, { ALT: () => this.CONSUME(Abort) }, @@ -2875,7 +2921,7 @@ class QuestDBParser extends CstParser { private compileViewStatement = this.RULE("compileViewStatement", () => { this.CONSUME(Compile) this.CONSUME(View) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) }) // ========================================================================== @@ -2999,7 +3045,7 @@ class QuestDBParser extends CstParser { }) private grantTableTarget = this.RULE("grantTableTarget", () => { - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) this.OPTION(() => { this.CONSUME(LParen) this.SUBRULE(this.identifier) @@ -3049,7 +3095,7 @@ class QuestDBParser extends CstParser { private vacuumTableStatement = this.RULE("vacuumTableStatement", () => { this.CONSUME(Vacuum) this.CONSUME(Table) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) }) private resumeWalStatement = this.RULE("resumeWalStatement", () => { @@ -3078,7 +3124,7 @@ class QuestDBParser extends CstParser { private reindexTableStatement = this.RULE("reindexTableStatement", () => { this.CONSUME(Reindex) this.CONSUME(Table) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) this.OPTION(() => { this.CONSUME(Column) this.SUBRULE(this.identifier) @@ -3111,7 +3157,7 @@ class QuestDBParser extends CstParser { this.CONSUME(Refresh) this.CONSUME(Materialized) this.CONSUME(View) - this.SUBRULE(this.qualifiedName) + this.SUBRULE(this.tableName) this.OPTION(() => { this.OR([ { ALT: () => this.CONSUME(Full) }, @@ -3143,7 +3189,7 @@ class QuestDBParser extends CstParser { this.CONSUME1(RParen) }, }, - { ALT: () => this.SUBRULE(this.qualifiedName) }, + { ALT: () => this.SUBRULE(this.tableName) }, ]) this.OPTION(() => this.SUBRULE(this.whereClause)) this.CONSUME(Pivot) @@ -3966,6 +4012,22 @@ class QuestDBParser extends CstParser { this.SUBRULE(this.qualifiedName) }) + // Wrapper for qualifiedName in table/view name positions. + // Mirrors columnRef but for table references, so computeContentAssist + // ruleStack includes "tableName" → autocomplete suggests existing tables. + private tableName = this.RULE("tableName", () => { + this.SUBRULE(this.qualifiedName) + }) + + // Accepts StringLiteral or tableName. Used for table references that + // can be quoted (INSERT INTO, RENAME TABLE, ALTER TABLE). + private tableNameOrString = this.RULE("tableNameOrString", () => { + this.OR([ + { ALT: () => this.CONSUME(StringLiteral) }, + { ALT: () => this.SUBRULE(this.tableName) }, + ]) + }) + private qualifiedName = this.RULE("qualifiedName", () => { this.SUBRULE(this.identifier) this.MANY(() => { diff --git a/src/parser/visitor.ts b/src/parser/visitor.ts index ce7ee4b..fae9b7e 100644 --- a/src/parser/visitor.ts +++ b/src/parser/visitor.ts @@ -28,6 +28,7 @@ import type { ArrayElementCstChildren, ArrayLiteralCstChildren, ArraySubscriptCstChildren, + AsofLtJoinCstChildren, AssumeServiceAccountStatementCstChildren, BackupStatementCstChildren, BatchClauseCstChildren, @@ -137,12 +138,16 @@ import type { ShowStatementCstChildren, SimpleSelectCstChildren, SnapshotStatementCstChildren, + SpliceJoinCstChildren, + StandardJoinCstChildren, StatementCstChildren, StatementsCstChildren, StringOrIdentifierCstChildren, StringOrQualifiedNameCstChildren, TableFunctionCallCstChildren, TableFunctionNameCstChildren, + TableNameCstChildren, + TableNameOrStringCstChildren, TableParamCstChildren, TableParamNameCstChildren, TableRefCstChildren, @@ -159,6 +164,7 @@ import type { WindowFrameBoundCstChildren, WindowFrameClauseCstChildren, WindowJoinBoundCstChildren, + WindowJoinCstChildren, WindowPartitionByClauseCstChildren, WithClauseCstChildren, WithStatementCstChildren, @@ -596,7 +602,7 @@ class QuestDBVisitor extends BaseVisitor { parts: [ctx.StringLiteral[0].image.slice(1, -1)], } as AST.QualifiedName } else { - table = this.visit(ctx.qualifiedName!) as AST.QualifiedName + table = this.visit(ctx.tableName!) as AST.QualifiedName } const result: AST.TableRef = { @@ -650,48 +656,75 @@ class QuestDBVisitor extends BaseVisitor { } joinClause(ctx: JoinClauseCstChildren): AST.JoinClause { + if (ctx.asofLtJoin) return this.visit(ctx.asofLtJoin) as AST.JoinClause + if (ctx.spliceJoin) return this.visit(ctx.spliceJoin) as AST.JoinClause + if (ctx.windowJoin) return this.visit(ctx.windowJoin) as AST.JoinClause + return this.visit(ctx.standardJoin!) as AST.JoinClause + } + + asofLtJoin(ctx: AsofLtJoinCstChildren): AST.JoinClause { const result: AST.JoinClause = { type: "join", table: this.visit(ctx.tableRef) as AST.TableRef, } - - // Determine join type - if (ctx.Inner) result.joinType = "inner" - else if (ctx.Left) result.joinType = "left" - else if (ctx.Right) result.joinType = "right" - else if (ctx.Full) result.joinType = "full" - else if (ctx.Cross) result.joinType = "cross" - else if (ctx.Asof) result.joinType = "asof" + if (ctx.Asof) result.joinType = "asof" else if (ctx.Lt) result.joinType = "lt" - else if (ctx.Splice) result.joinType = "splice" - else if (ctx.Window) result.joinType = "window" - - if (ctx.Outer) { - result.outer = true - } - if (ctx.expression) { result.on = this.visit(ctx.expression) as AST.Expression } - - // Handle TOLERANCE clause for ASOF/LT joins if (ctx.DurationLiteral) { result.tolerance = ctx.DurationLiteral[0].image } + return result + } + + spliceJoin(ctx: SpliceJoinCstChildren): AST.JoinClause { + const result: AST.JoinClause = { + type: "join", + joinType: "splice", + table: this.visit(ctx.tableRef) as AST.TableRef, + } + if (ctx.expression) { + result.on = this.visit(ctx.expression) as AST.Expression + } + return result + } - // Handle RANGE BETWEEN clause for WINDOW JOIN + windowJoin(ctx: WindowJoinCstChildren): AST.JoinClause { + const result: AST.JoinClause = { + type: "join", + table: this.visit(ctx.tableRef) as AST.TableRef, + } + if (ctx.Window) result.joinType = "window" + if (ctx.expression) { + result.on = this.visit(ctx.expression) as AST.Expression + } if (ctx.windowJoinBound && ctx.windowJoinBound.length >= 2) { result.range = { start: this.visit(ctx.windowJoinBound[0]) as AST.WindowJoinBound, end: this.visit(ctx.windowJoinBound[1]) as AST.WindowJoinBound, } } - - // Handle INCLUDE/EXCLUDE PREVAILING clause for WINDOW JOIN - if (ctx.Prevailing) { + if (ctx.Include || ctx.Exclude) { result.prevailing = ctx.Include ? "include" : "exclude" } + return result + } + standardJoin(ctx: StandardJoinCstChildren): AST.JoinClause { + const result: AST.JoinClause = { + type: "join", + table: this.visit(ctx.tableRef) as AST.TableRef, + } + if (ctx.Inner) result.joinType = "inner" + else if (ctx.Left) result.joinType = "left" + else if (ctx.Right) result.joinType = "right" + else if (ctx.Full) result.joinType = "full" + else if (ctx.Cross) result.joinType = "cross" + if (ctx.Outer) result.outer = true + if (ctx.expression) { + result.on = this.visit(ctx.expression) as AST.Expression + } return result } @@ -865,7 +898,7 @@ class QuestDBVisitor extends BaseVisitor { insertStatement(ctx: InsertStatementCstChildren): AST.InsertStatement { const result: AST.InsertStatement = { type: "insert", - table: this.visit(ctx.stringOrQualifiedName) as AST.QualifiedName, + table: this.visit(ctx.tableNameOrString) as AST.QualifiedName, } if (ctx.Atomic) { @@ -923,7 +956,7 @@ class QuestDBVisitor extends BaseVisitor { updateStatement(ctx: UpdateStatementCstChildren): AST.UpdateStatement { const result: AST.UpdateStatement = { type: "update", - table: this.visit(ctx.qualifiedName) as AST.QualifiedName, + table: this.visit(ctx.tableName) as AST.QualifiedName, set: ctx.setClause.map((s: CstNode) => this.visit(s) as AST.SetClause), } @@ -1039,8 +1072,8 @@ class QuestDBVisitor extends BaseVisitor { ) } - if (ctx.Like && ctx.qualifiedName) { - result.like = this.visit(ctx.qualifiedName[0]) as AST.QualifiedName + if (ctx.Like && ctx.tableName) { + result.like = this.visit(ctx.tableName[0]) as AST.QualifiedName } if (ctx.selectStatement) { @@ -1521,12 +1554,7 @@ class QuestDBVisitor extends BaseVisitor { alterTableStatement( ctx: AlterTableStatementCstChildren, ): AST.AlterTableStatement { - const table = ctx.qualifiedName - ? (this.visit(ctx.qualifiedName) as AST.QualifiedName) - : { - type: "qualifiedName" as const, - parts: [ctx.StringLiteral![0].image.slice(1, -1)], - } + const table = this.visit(ctx.tableNameOrString) as AST.QualifiedName return { type: "alterTable", table, @@ -1539,7 +1567,7 @@ class QuestDBVisitor extends BaseVisitor { ): AST.AlterMaterializedViewStatement { return { type: "alterMaterializedView", - view: this.visit(ctx.qualifiedName) as AST.QualifiedName, + view: this.visit(ctx.tableName) as AST.QualifiedName, action: this.visit( ctx.alterMaterializedViewAction, ) as AST.AlterMaterializedViewAction, @@ -1550,9 +1578,10 @@ class QuestDBVisitor extends BaseVisitor { ctx: AlterMaterializedViewActionCstChildren, ): AST.AlterMaterializedViewAction { if (ctx.Add && ctx.Index) { + const colRef = this.visit(ctx.columnRef![0]) as AST.ColumnRef const result: AST.AlterMaterializedViewAddIndex = { actionType: "addIndex", - column: this.extractIdentifierName(ctx.identifier![0].children), + column: colRef.name.parts[colRef.name.parts.length - 1], } if (ctx.Capacity && ctx.NumberLiteral) { result.capacity = parseInt(ctx.NumberLiteral[0].image, 10) @@ -1561,9 +1590,10 @@ class QuestDBVisitor extends BaseVisitor { } if (ctx.Symbol && ctx.Capacity) { + const colRef = this.visit(ctx.columnRef![0]) as AST.ColumnRef return { actionType: "symbolCapacity", - column: this.extractIdentifierName(ctx.identifier![0].children), + column: colRef.name.parts[colRef.name.parts.length - 1], capacity: parseInt(ctx.NumberLiteral![0].image, 10), } } @@ -1584,9 +1614,10 @@ class QuestDBVisitor extends BaseVisitor { // ALTER COLUMN x DROP INDEX if (ctx.Alter && ctx.Drop && ctx.Index) { + const colRef = this.visit(ctx.columnRef![0]) as AST.ColumnRef return { actionType: "dropIndex", - column: this.extractIdentifierName(ctx.identifier![0].children), + column: colRef.name.parts[colRef.name.parts.length - 1], } } @@ -1700,30 +1731,32 @@ class QuestDBVisitor extends BaseVisitor { } // DROP COLUMN (when Drop token exists but no Partition and no Alter — Alter + Drop = ALTER COLUMN DROP INDEX) - if (ctx.Drop && ctx.identifier && !ctx.Partition && !ctx.Alter) { + if (ctx.Drop && ctx.columnRef && !ctx.Partition && !ctx.Alter) { return { actionType: "dropColumn", - columns: ctx.identifier.map((id: IdentifierCstNode) => - this.extractIdentifierName(id.children), - ), + columns: ctx.columnRef.map((c: CstNode) => { + const ref = this.visit(c) as AST.ColumnRef + return ref.name.parts[ref.name.parts.length - 1] + }), } } // RENAME COLUMN if (ctx.Rename) { - const identifiers = ctx.identifier!.map((id: IdentifierCstNode) => - this.extractIdentifierName(id.children), - ) + const oldRef = this.visit(ctx.columnRef![0]) as AST.ColumnRef + const oldName = oldRef.name.parts[oldRef.name.parts.length - 1] + const newName = this.extractIdentifierName(ctx.identifier![0].children) return { actionType: "renameColumn", - oldName: identifiers[0], - newName: identifiers[1], + oldName, + newName, } } // ALTER COLUMN - if (ctx.Alter && ctx.identifier) { - const column = this.extractIdentifierName(ctx.identifier[0].children) + if (ctx.Alter && ctx.columnRef) { + const colRef = this.visit(ctx.columnRef[0]) as AST.ColumnRef + const column = colRef.name.parts[colRef.name.parts.length - 1] let alterType: | "type" | "addIndex" @@ -1944,7 +1977,7 @@ class QuestDBVisitor extends BaseVisitor { if (ctx.All) { result.allTables = true } else { - result.table = this.visit(ctx.qualifiedName!) as AST.QualifiedName + result.table = this.visit(ctx.tableName!) as AST.QualifiedName if (ctx.If) { result.ifExists = true } @@ -2006,7 +2039,7 @@ class QuestDBVisitor extends BaseVisitor { truncateTableStatement( ctx: TruncateTableStatementCstChildren, ): AST.TruncateTableStatement { - const tables = ctx.qualifiedName.map( + const tables = ctx.tableName.map( (qn) => this.visit(qn) as AST.QualifiedName, ) @@ -2039,8 +2072,8 @@ class QuestDBVisitor extends BaseVisitor { ): AST.RenameTableStatement { return { type: "renameTable", - from: this.visit(ctx.stringOrQualifiedName[0]) as AST.QualifiedName, - to: this.visit(ctx.stringOrQualifiedName[1]) as AST.QualifiedName, + from: this.visit(ctx.tableNameOrString[0]) as AST.QualifiedName, + to: this.visit(ctx.tableNameOrString[1]) as AST.QualifiedName, } } @@ -2297,7 +2330,7 @@ class QuestDBVisitor extends BaseVisitor { copyFrom(ctx: CopyFromCstChildren): AST.CopyFromStatement { const result: AST.CopyFromStatement = { type: "copyFrom", - table: this.visit(ctx.qualifiedName) as AST.QualifiedName, + table: this.visit(ctx.tableName) as AST.QualifiedName, file: this.extractMaybeString(ctx.stringOrIdentifier[0]), } if (ctx.copyOptions) { @@ -2309,7 +2342,7 @@ class QuestDBVisitor extends BaseVisitor { copyTo(ctx: CopyToCstChildren): AST.CopyToStatement { const source = ctx.selectStatement ? (this.visit(ctx.selectStatement[0]) as AST.SelectStatement) - : (this.visit(ctx.qualifiedName!) as AST.QualifiedName) + : (this.visit(ctx.tableName!) as AST.QualifiedName) const result: AST.CopyToStatement = { type: "copyTo", source, @@ -2444,7 +2477,7 @@ class QuestDBVisitor extends BaseVisitor { return { type: "backup", action: "table", - table: this.visit(ctx.qualifiedName!) as AST.QualifiedName, + table: this.visit(ctx.tableName!) as AST.QualifiedName, } } @@ -2453,7 +2486,7 @@ class QuestDBVisitor extends BaseVisitor { ): AST.CompileViewStatement { return { type: "compileView", - view: this.visit(ctx.qualifiedName) as AST.QualifiedName, + view: this.visit(ctx.tableName) as AST.QualifiedName, } } @@ -2582,7 +2615,7 @@ class QuestDBVisitor extends BaseVisitor { grantTableTarget(ctx: GrantTableTargetCstChildren): AST.GrantTableTarget { const result: AST.GrantTableTarget = { type: "grantTableTarget", - table: this.visit(ctx.qualifiedName) as AST.QualifiedName, + table: this.visit(ctx.tableName) as AST.QualifiedName, } if (ctx.identifier && ctx.identifier.length > 0) { result.columns = ctx.identifier.map((id: IdentifierCstNode) => @@ -2622,7 +2655,7 @@ class QuestDBVisitor extends BaseVisitor { ): AST.VacuumTableStatement { return { type: "vacuumTable", - table: this.visit(ctx.qualifiedName) as AST.QualifiedName, + table: this.visit(ctx.tableName) as AST.QualifiedName, } } @@ -2654,7 +2687,7 @@ class QuestDBVisitor extends BaseVisitor { ): AST.ReindexTableStatement { const result: AST.ReindexTableStatement = { type: "reindexTable", - table: this.visit(ctx.qualifiedName) as AST.QualifiedName, + table: this.visit(ctx.tableName) as AST.QualifiedName, } if (ctx.Column && ctx.identifier) { result.columns = ctx.identifier.map((id: IdentifierCstNode) => @@ -2677,7 +2710,7 @@ class QuestDBVisitor extends BaseVisitor { ): AST.RefreshMaterializedViewStatement { const result: AST.RefreshMaterializedViewStatement = { type: "refreshMaterializedView", - view: this.visit(ctx.qualifiedName) as AST.QualifiedName, + view: this.visit(ctx.tableName) as AST.QualifiedName, } if (ctx.Full) result.mode = "full" if (ctx.Incremental) result.mode = "incremental" @@ -2698,7 +2731,7 @@ class QuestDBVisitor extends BaseVisitor { pivotStatement(ctx: PivotStatementCstChildren): AST.PivotStatement { const source = ctx.selectStatement ? (this.visit(ctx.selectStatement[0]) as AST.SelectStatement) - : (this.visit(ctx.qualifiedName!) as AST.QualifiedName) + : (this.visit(ctx.tableName!) as AST.QualifiedName) const body: Partial = ctx.pivotBody ? (this.visit(ctx.pivotBody) as PivotBodyResult) : {} @@ -3746,6 +3779,24 @@ class QuestDBVisitor extends BaseVisitor { } } + tableName(ctx: TableNameCstChildren): AST.QualifiedName { + return this.visit(ctx.qualifiedName) as AST.QualifiedName + } + + tableNameOrString( + ctx: TableNameOrStringCstChildren, + ): AST.QualifiedName { + if (ctx.StringLiteral) { + return { + type: "qualifiedName", + parts: [ctx.StringLiteral[0].image.slice(1, -1)], + } + } + if (ctx.tableName) + return this.visit(ctx.tableName) as AST.QualifiedName + return { type: "qualifiedName", parts: [] } + } + qualifiedName(ctx: QualifiedNameCstChildren): AST.QualifiedName { const parts: string[] = ctx.identifier.map((id: CstNode) => { return this.extractIdentifierName(id.children) diff --git a/tests/autocomplete.test.ts b/tests/autocomplete.test.ts index 7647b8b..58c4fe6 100644 --- a/tests/autocomplete.test.ts +++ b/tests/autocomplete.test.ts @@ -673,6 +673,102 @@ describe("JOIN autocomplete", () => { }, ]) }) + + it("should NOT suggest OUTER after ASOF", () => { + const labels = getLabelsAt(provider, "SELECT * FROM trades ASOF ") + expect(labels).not.toContain("OUTER") + }) + + it("should suggest OUTER after LEFT", () => { + const labels = getLabelsAt(provider, "SELECT * FROM trades LEFT ") + expect(labels).toContain("OUTER") + expect(labels).toContain("JOIN") + }) + + it("should suggest join types after ON condition (chained joins)", () => { + const labels = getLabelsAt( + provider, + "SELECT * FROM trades t ASOF JOIN quotes q ON (symbol) ", + ) + expect(labels).toContain("ASOF") + expect(labels).toContain("JOIN") + expect(labels).toContain("CROSS") + expect(labels).toContain("LEFT") + }) + }) + + describe("join-type-specific postamble suggestions", () => { + it("ASOF JOIN: should suggest ON and TOLERANCE, not INCLUDE/EXCLUDE/RANGE", () => { + const labels = getLabelsAt( + provider, + "SELECT * FROM trades t ASOF JOIN quotes q ", + ) + expect(labels).toContain("ON") + expect(labels).toContain("TOLERANCE") + expect(labels).not.toContain("INCLUDE") + expect(labels).not.toContain("EXCLUDE") + expect(labels).not.toContain("RANGE") + }) + + it("LT JOIN: should suggest ON and TOLERANCE, not INCLUDE/EXCLUDE/RANGE", () => { + const labels = getLabelsAt( + provider, + "SELECT * FROM trades t LT JOIN quotes q ", + ) + expect(labels).toContain("ON") + expect(labels).toContain("TOLERANCE") + expect(labels).not.toContain("INCLUDE") + expect(labels).not.toContain("EXCLUDE") + expect(labels).not.toContain("RANGE") + }) + + it("SPLICE JOIN: should suggest ON, not TOLERANCE/INCLUDE/EXCLUDE/RANGE", () => { + const labels = getLabelsAt( + provider, + "SELECT * FROM trades t SPLICE JOIN quotes q ", + ) + expect(labels).toContain("ON") + expect(labels).not.toContain("TOLERANCE") + expect(labels).not.toContain("INCLUDE") + expect(labels).not.toContain("EXCLUDE") + expect(labels).not.toContain("RANGE") + }) + + it("WINDOW JOIN: should suggest ON, RANGE, INCLUDE, EXCLUDE, not TOLERANCE", () => { + const labels = getLabelsAt( + provider, + "SELECT * FROM trades t WINDOW JOIN quotes q ", + ) + expect(labels).toContain("ON") + expect(labels).toContain("RANGE") + expect(labels).toContain("INCLUDE") + expect(labels).toContain("EXCLUDE") + expect(labels).not.toContain("TOLERANCE") + }) + + it("INNER JOIN: should suggest ON, not TOLERANCE/INCLUDE/EXCLUDE/RANGE", () => { + const labels = getLabelsAt( + provider, + "SELECT * FROM trades t INNER JOIN quotes q ", + ) + expect(labels).toContain("ON") + expect(labels).not.toContain("TOLERANCE") + expect(labels).not.toContain("INCLUDE") + expect(labels).not.toContain("EXCLUDE") + expect(labels).not.toContain("RANGE") + }) + + it("LEFT JOIN: should suggest ON, not TOLERANCE/INCLUDE/EXCLUDE/RANGE", () => { + const labels = getLabelsAt( + provider, + "SELECT * FROM trades t LEFT JOIN quotes q ", + ) + expect(labels).toContain("ON") + expect(labels).not.toContain("TOLERANCE") + expect(labels).not.toContain("INCLUDE") + expect(labels).not.toContain("EXCLUDE") + expect(labels).not.toContain("RANGE") + }) }) }) @@ -1337,8 +1433,10 @@ describe("CTE autocomplete", () => { it("should suggest identifier after comma between CTEs", () => { const sql = "WITH cte AS (SELECT * FROM users LIMIT 10), " const labels = getLabelsAt(provider, sql) - // Should be able to type a new CTE name - expect(labels.length).toBeGreaterThan(0) + // This position expects a new CTE name — the grammar correctly identifies + // it as a newName position (not a column or table reference). + // No column/table suggestions expected; the user types a free-form name. + expect(labels.every((l) => l !== "symbol" && l !== "price")).toBe(true) }) it("should not leak inner CTE source table columns into outer scope", () => { @@ -2658,4 +2756,150 @@ describe("CTE autocomplete", () => { expect(columns).toHaveLength(0) }) }) + + // =========================================================================== + // Grammar-level tableName classification tests + // =========================================================================== + // These tests verify that the grammar-based `tableName` rule correctly + // classifies positions as table name vs column vs new name contexts. + + describe("grammar-level tableName classification", () => { + it("CREATE TABLE (LIKE |) suggests tables, not columns", () => { + const suggestions = provider.getSuggestions( + "CREATE TABLE mytable (LIKE ", + 27, + ) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(tables.length).toBeGreaterThan(0) + expect(tables.map((s) => s.label)).toContain("trades") + expect(columns).toHaveLength(0) + }) + + it("DROP TABLE suggests tables, not columns", () => { + const suggestions = provider.getSuggestions("DROP TABLE ", 11) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(tables.map((s) => s.label)).toContain("trades") + expect(columns).toHaveLength(0) + }) + + it("TRUNCATE TABLE suggests tables, not columns", () => { + const suggestions = provider.getSuggestions("TRUNCATE TABLE ", 15) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(tables.map((s) => s.label)).toContain("trades") + expect(columns).toHaveLength(0) + }) + + it("ALTER TABLE suggests tables, not columns", () => { + const suggestions = provider.getSuggestions("ALTER TABLE ", 12) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(tables.map((s) => s.label)).toContain("trades") + expect(columns).toHaveLength(0) + }) + + it("INSERT INTO suggests tables, not columns", () => { + const suggestions = provider.getSuggestions("INSERT INTO ", 12) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(tables.map((s) => s.label)).toContain("trades") + expect(columns).toHaveLength(0) + }) + + it("SELECT clause suggests columns, not just tables", () => { + const suggestions = provider.getSuggestions("SELECT FROM trades", 7) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(columns.map((s) => s.label)).toContain("symbol") + }) + + it("WHERE clause suggests columns", () => { + const suggestions = provider.getSuggestions( + "SELECT * FROM trades WHERE ", + 27, + ) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(columns.map((s) => s.label)).toContain("price") + }) + + it("CREATE TABLE column definition: no columns, no tables", () => { + const suggestions = provider.getSuggestions("CREATE TABLE mytable (", 22) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + expect(columns).toHaveLength(0) + expect(tables).toHaveLength(0) + }) + + it("INSERT VALUES: no columns", () => { + const suggestions = provider.getSuggestions( + "INSERT INTO trades VALUES (", + 27, + ) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(columns).toHaveLength(0) + }) + + it("VACUUM TABLE suggests tables, not columns", () => { + const suggestions = provider.getSuggestions("VACUUM TABLE ", 13) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(tables.map((s) => s.label)).toContain("trades") + expect(columns).toHaveLength(0) + }) + + it("COPY TO table position suggests tables, not columns", () => { + const suggestions = provider.getSuggestions("COPY ", 5) + const tables = suggestions.filter((s) => s.kind === SuggestionKind.Table) + expect(tables.map((s) => s.label)).toContain("trades") + }) + + it("ALTER TABLE trades ALTER COLUMN suggests columns", () => { + const sql = "ALTER TABLE trades ALTER COLUMN " + const suggestions = provider.getSuggestions(sql, sql.length) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(columns.map((s) => s.label)).toContain("symbol") + expect(columns.map((s) => s.label)).toContain("price") + }) + + it("ALTER TABLE trades DROP COLUMN suggests columns", () => { + const sql = "ALTER TABLE trades DROP COLUMN " + const suggestions = provider.getSuggestions(sql, sql.length) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(columns.map((s) => s.label)).toContain("symbol") + }) + + it("ALTER TABLE trades RENAME COLUMN suggests columns for old name", () => { + const sql = "ALTER TABLE trades RENAME COLUMN " + const suggestions = provider.getSuggestions(sql, sql.length) + const columns = suggestions.filter( + (s) => s.kind === SuggestionKind.Column, + ) + expect(columns.map((s) => s.label)).toContain("symbol") + }) + }) }) diff --git a/tests/docs-autocomplete.test.ts b/tests/docs-autocomplete.test.ts index 9dbb3ae..bb3cefb 100644 --- a/tests/docs-autocomplete.test.ts +++ b/tests/docs-autocomplete.test.ts @@ -1,19 +1,38 @@ /** * Documentation SQL - Autocomplete Walkthrough Tests * - * For each SQL statement, tokenizes it and walks through token-by-token, - * verifying that at each position the content-assist correctly predicts the next token. + * For each SQL statement from docs-queries.json, tokenizes it and walks + * through token-by-token, verifying that at each position the autocomplete + * provider's getSuggestions() includes the actual next word as a suggestion. * - * Queries with skipAutocomplete: true in the fixture are known edge cases in - * Chevrotain's computeContentAssist (implicit SELECT, SQL hints, semicolons, - * array slice colon syntax) — not actual autocomplete bugs. + * This tests the REAL end-to-end autocomplete behavior that users see. + * + * Per-query schema extraction: + * 1. Parse the query with parseToAst() to get the AST + * 2. Walk the AST to extract table names and column definitions + * 3. Create a provider with the extracted schema + * 4. Walk token-by-token, checking that each keyword and known identifier + * appears in the suggestion list + * + * Queries with skipAutocomplete: true are known edge cases (implicit SELECT, + * SQL hints, semicolons, array slice colon syntax). * * Source of truth: tests/fixtures/docs-queries.json */ import { describe, it } from "vitest" -import { tokenize, getNextValidTokens } from "../src/index" -import { IDENTIFIER_KEYWORD_TOKENS } from "../src/autocomplete/token-classification" +import { + tokenize, + createAutocompleteProvider, + parseToAst, + getNextValidTokens, +} from "../src/index" +import type { SchemaInfo } from "../src/autocomplete/types" +import type * as AST from "../src/parser/ast" +import { + tokenNameToKeyword, + IDENTIFIER_KEYWORD_TOKENS, +} from "../src/autocomplete/token-classification" import * as fs from "fs" import * as path from "path" @@ -27,39 +46,245 @@ const queries: DocsQuery[] = JSON.parse( fs.readFileSync(fixtureP, "utf-8"), ) as DocsQuery[] -function isTokenExpectedAtPosition( - actualTokenType: string, - expectedTokenNames: string[], -): boolean { - if (expectedTokenNames.includes(actualTokenType)) return true +// ============================================================================= +// Schema extraction from AST +// ============================================================================= + +interface ExtractedSchema { + tables: Set + columns: Map> // tableName → column names +} + +/** + * Extract table and column names from a parsed AST. + * Walks all nodes recursively to find table references and column definitions. + */ +function extractSchemaFromAst(statements: AST.Statement[]): ExtractedSchema { + const tables = new Set() + const columns = new Map>() + + function addColumn(table: string, col: string) { + const lower = table.toLowerCase() + if (!columns.has(lower)) columns.set(lower, new Set()) + columns.get(lower)!.add(col) + } + + function getTableName(name: AST.QualifiedName): string { + return name.parts[name.parts.length - 1] + } + + function walkNode(node: unknown): void { + if (!node || typeof node !== "object") return + const n = node as Record + + // Handle arrays + if (Array.isArray(node)) { + for (const item of node) walkNode(item) + return + } + + const type = n.type as string | undefined - if (expectedTokenNames.includes("Identifier")) { - if (IDENTIFIER_KEYWORD_TOKENS.has(actualTokenType)) return true - if (actualTokenType === "QuotedIdentifier") return true + // TableRef: extract table name from FROM/JOIN + if (type === "tableRef") { + const ref = node as AST.TableRef + if ( + ref.table && + (ref.table as AST.QualifiedName).type === "qualifiedName" + ) { + const name = getTableName(ref.table as AST.QualifiedName) + tables.add(name) + } + // Recurse into joins + if (ref.joins) walkNode(ref.joins) + // Recurse into subqueries + if (ref.table && (ref.table as AST.SelectStatement).type === "select") { + walkNode(ref.table) + } + return + } + + // JoinClause: extract table from joined table + if (type === "join") { + const join = node as AST.JoinClause + walkNode(join.table) + if (join.on) walkNode(join.on) + return + } + + // CreateTable: extract table name and column definitions + if (type === "createTable") { + const ct = node as AST.CreateTableStatement + const tableName = getTableName(ct.table) + tables.add(tableName) + if (ct.columns) { + for (const col of ct.columns) { + addColumn(tableName, col.name) + } + } + if (ct.asSelect) walkNode(ct.asSelect) + if (ct.like) tables.add(getTableName(ct.like)) + return + } + + // InsertStatement: extract table name and column names + if (type === "insert") { + const ins = node as AST.InsertStatement + const tableName = getTableName(ins.table) + tables.add(tableName) + if (ins.columns) { + for (const col of ins.columns) { + addColumn(tableName, col) + } + } + if (ins.select) walkNode(ins.select) + return + } + + // SelectStatement: recurse into all parts + if (type === "select") { + const sel = node as AST.SelectStatement + if (sel.from) walkNode(sel.from) + if (sel.columns) walkNode(sel.columns) + if (sel.where) walkNode(sel.where) + if (sel.with) { + for (const cte of sel.with) { + if (cte.name) tables.add(cte.name) + if (cte.query) walkNode(cte.query) + } + } + if (sel.groupBy) walkNode(sel.groupBy) + if (sel.orderBy) walkNode(sel.orderBy) + if (sel.sampleBy) walkNode(sel.sampleBy) + if (sel.latestOn) walkNode(sel.latestOn) + if (sel.declare) walkNode(sel.declare) + return + } + + // AlterTable: extract table name + if (type === "alterTable") { + const alt = node as AST.AlterTableStatement + tables.add(getTableName(alt.table)) + return + } + + // UpdateStatement: extract table name + if (type === "update") { + const upd = node as AST.UpdateStatement + tables.add(getTableName(upd.table)) + if (upd.from) walkNode(upd.from) + if (upd.where) walkNode(upd.where) + return + } + + // For all other node types, recurse into all object/array properties + for (const value of Object.values(n)) { + if (value && typeof value === "object") { + walkNode(value) + } + } } - // When the token limit fallback fires, computeSuggestions returns only - // [IdentifierKeyword]. This is a performance trade-off for large queries, - // not an autocomplete bug — accept any token in this degraded mode. - if ( - expectedTokenNames.length === 1 && - expectedTokenNames[0] === "IdentifierKeyword" - ) { - return true + for (const stmt of statements) { + walkNode(stmt) } - return false + return { tables, columns } } +/** + * Build a SchemaInfo from extracted table/column data + */ +function buildSchema(extracted: ExtractedSchema): SchemaInfo { + const tableList = Array.from(extracted.tables).map((name) => ({ name })) + const columnsMap: Record = {} + + for (const [tableName, cols] of extracted.columns) { + columnsMap[tableName] = Array.from(cols).map((name) => ({ + name, + type: "STRING", + })) + } + + return { tables: tableList, columns: columnsMap } +} + +// ============================================================================= +// Token classification +// ============================================================================= + +/** + * Token types that we cannot verify via suggestion labels: + * - Literals: user-provided values, not suggested + * - Punctuation and operators: not shown as keyword suggestions + * - Variable references: user-defined + */ +const SKIP_TOKEN_TYPES = new Set([ + // Literals + "StringLiteral", + "NumberLiteral", + "DurationLiteral", + "BooleanLiteral", + "NullLiteral", + "GeohashLiteral", + "GeohashBinaryLiteral", + "DecimalLiteral", + "LongLiteral", + // Punctuation + "LParen", + "RParen", + "Comma", + "Dot", + "Semicolon", + "LBracket", + "RBracket", + "AtSign", + "ColonEquals", + // Operators + "Equals", + "NotEquals", + "LessThan", + "LessThanOrEqual", + "GreaterThan", + "GreaterThanOrEqual", + "Plus", + "Minus", + "Star", + "Divide", + "Modulo", + "Concat", + "DoubleColon", + "RegexMatch", + "RegexNotMatch", + "RegexNotEquals", + "IPv4ContainedBy", + "IPv4Contains", + "BitXor", + "BitOr", + "BitAnd", + // Variable references + "VariableReference", + // Interval literals + "IntervalLiteral", +]) + +// ============================================================================= +// Walkthrough logic +// ============================================================================= + interface WalkthroughStep { position: number tokenImage: string tokenType: string - expectedTokens: string[] - isExpected: boolean + expectedLabel: string + suggestions: string[] + found: boolean } -function autocompleteWalkthrough(sql: string): { +function autocompleteWalkthrough( + sql: string, + schema: SchemaInfo, +): { success: boolean steps: WalkthroughStep[] failedSteps: WalkthroughStep[] @@ -69,31 +294,99 @@ function autocompleteWalkthrough(sql: string): { return { success: false, steps: [], failedSteps: [] } } + const provider = createAutocompleteProvider(schema) + + // Build a set of known table and column names for identifier matching + const knownNames = new Set() + for (const t of schema.tables) knownNames.add(t.name.toLowerCase()) + for (const cols of Object.values(schema.columns)) { + for (const c of cols) knownNames.add(c.name.toLowerCase()) + } + const steps: WalkthroughStep[] = [] for (let i = 0; i < tokens.length; i++) { const token = tokens[i] + const tokenType = token.tokenType.name + + // Skip literals, punctuation, operators + if (SKIP_TOKEN_TYPES.has(tokenType)) continue + const prefix = sql.substring(0, token.startOffset) - const expectedTokenNames = getNextValidTokens(prefix) - const actualType = token.tokenType.name - const isExpected = isTokenExpectedAtPosition(actualType, expectedTokenNames) + const labels = provider + .getSuggestions(prefix, prefix.length) + .map((s) => s.label) + + let found = false + let expectedLabel: string + + if (tokenType === "Identifier" || tokenType === "QuotedIdentifier") { + // Identifier token: check if it's a known table/column name (reference) + // or an unknown name (new table, alias, new column definition) → skip + const word = + tokenType === "QuotedIdentifier" + ? token.image.slice(1, -1).toLowerCase() + : token.image.toLowerCase() + if (!knownNames.has(word)) continue + + // If the provider suggests ANY known name at this position, it's a + // reference position (FROM, JOIN, WHERE, etc.) → verify our name is there. + // If it suggests none, it's a definition position (CREATE TABLE name, + // column definition, alias) → skip. + const suggestsAnyKnownName = labels.some((l) => + knownNames.has(l.toLowerCase()), + ) + if (!suggestsAnyKnownName) continue + + expectedLabel = word + found = labels.some((l) => l.toLowerCase() === word) + } else { + // Keyword token: check if keyword label appears in suggestions + expectedLabel = tokenNameToKeyword(tokenType) + const expectedUpper = expectedLabel.toUpperCase() + found = labels.some((l) => l.toUpperCase() === expectedUpper) + + // Fallback: keyword tokens used as identifiers (e.g., `timestamp` + // as a column name). Accept if the word is a known column/table in + // the schema OR IdentifierKeyword is expected at this position. + if (!found && IDENTIFIER_KEYWORD_TOKENS.has(tokenType)) { + const word = token.image.toLowerCase() + if (knownNames.has(word)) { + found = labels.some((l) => l.toLowerCase() === word) + } + if (!found) { + const rawTokens = getNextValidTokens(prefix) + if ( + rawTokens.includes("IdentifierKeyword") || + rawTokens.includes(tokenType) + ) { + found = true + } + } + } + } steps.push({ position: token.startOffset, tokenImage: token.image, - tokenType: actualType, - expectedTokens: expectedTokenNames, - isExpected, + tokenType, + expectedLabel, + suggestions: labels.slice(0, 15), + found, }) } return { success: true, steps, - failedSteps: steps.filter((s) => !s.isExpected), + failedSteps: steps.filter((s) => !s.found), } } +// ============================================================================= +// Test setup +// ============================================================================= + const testable = queries .map((q, i) => ({ ...q, index: i })) .filter((q) => !q.skipAutocomplete) @@ -107,7 +400,13 @@ describe("Documentation SQL - Autocomplete Walkthrough", () => { ]), )("%s", (_label, entry) => { const q = entry as DocsQuery & { index: number } - const result = autocompleteWalkthrough(q.query) + + // Extract schema from AST + const parseResult = parseToAst(q.query) + const extracted = extractSchemaFromAst(parseResult.ast) + const schema = buildSchema(extracted) + + const result = autocompleteWalkthrough(q.query, schema) if (!result.success) { return @@ -118,13 +417,13 @@ describe("Documentation SQL - Autocomplete Walkthrough", () => { .slice(0, 3) .map( (s) => - ` At offset ${s.position}: "${s.tokenImage}" (${s.tokenType}) ` + - `not in expected [${s.expectedTokens.slice(0, 10).join(", ")}${s.expectedTokens.length > 10 ? "..." : ""}]`, + ` At offset ${s.position}: expected "${s.expectedLabel}" (from token "${s.tokenImage}" [${s.tokenType}]) ` + + `not found in suggestions [${s.suggestions.join(", ")}${s.suggestions.length >= 15 ? "..." : ""}]`, ) .join("\n") throw new Error( - `Autocomplete walkthrough failed for #${q.index}:\n${failures}\n\nSQL: ${q.query.substring(0, 120)}`, + `Autocomplete walkthrough failed for #${q.index} (${result.failedSteps.length} failures):\n${failures}\n\nSQL: ${q.query.substring(0, 120)}`, ) } })