diff --git a/extension/java/CMakeLists.txt b/extension/java/CMakeLists.txt
new file mode 100644
index 00000000000..f6fab8193c0
--- /dev/null
+++ b/extension/java/CMakeLists.txt
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.24)
+
+# This can be built standalone or as part of the main ExecuTorch build
+if(NOT TARGET executorch)
+  project(executorch_java)
+  set(EXECUTORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../..")
+  include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+endif()
+
+if(NOT CMAKE_CXX_STANDARD)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+# Skip if building for Android (use extension/android instead)
+if(ANDROID)
+  message(STATUS "Skipping extension/java on Android - use extension/android instead")
+  return()
+endif()
+
+set(_common_compile_options
+  $<$<CXX_COMPILER_ID:MSVC>:/wd4996>
+  $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wno-deprecated-declarations;-fPIC>
+)
+
+# Find JNI - provide hints for common locations
+# Users can set JAVA_HOME environment variable or -DJAVA_HOME=<path>
+if(DEFINED ENV{JAVA_HOME})
+  set(JAVA_HOME $ENV{JAVA_HOME})
+endif()
+
+if(JAVA_HOME)
+  message(STATUS "Using JAVA_HOME: ${JAVA_HOME}")
+  set(JAVA_AWT_LIBRARY "${JAVA_HOME}/lib" CACHE PATH "AWT library path")
+  set(JAVA_JVM_LIBRARY "${JAVA_HOME}/lib" CACHE PATH "JVM library path")
+  set(JAVA_INCLUDE_PATH "${JAVA_HOME}/include" CACHE PATH "JNI include path")
+
+  # Platform-specific include path for jni_md.h
+  if(APPLE)
+    set(JAVA_INCLUDE_PATH2 "${JAVA_HOME}/include/darwin" CACHE PATH "JNI platform include path")
+  elseif(WIN32)
+    set(JAVA_INCLUDE_PATH2 "${JAVA_HOME}/include/win32" CACHE PATH "JNI platform include path")
+  else()
+    set(JAVA_INCLUDE_PATH2 "${JAVA_HOME}/include/linux" CACHE PATH "JNI platform include path")
+  endif()
+
+  # Set hints for FindJNI
+  set(JAVA_AWT_INCLUDE_PATH "${JAVA_HOME}/include" CACHE PATH "AWT
include path")
+endif()
+
+# Find JNI - we don't actually need AWT, so make it optional
+set(JNI_FIND_REQUIRED_AWT FALSE)
+find_package(JNI COMPONENTS JNI)
+
+if(NOT JNI_FOUND)
+  # Try again with explicit paths
+  if(JAVA_HOME)
+    set(JNI_INCLUDE_DIRS "${JAVA_HOME}/include")
+    if(APPLE)
+      list(APPEND JNI_INCLUDE_DIRS "${JAVA_HOME}/include/darwin")
+    elseif(WIN32)
+      list(APPEND JNI_INCLUDE_DIRS "${JAVA_HOME}/include/win32")
+    else()
+      list(APPEND JNI_INCLUDE_DIRS "${JAVA_HOME}/include/linux")
+    endif()
+    message(STATUS "JNI include dirs set manually: ${JNI_INCLUDE_DIRS}")
+  else()
+    message(FATAL_ERROR
+      "Could not find JNI. Please set JAVA_HOME environment variable or pass -DJAVA_HOME=<path>\n"
+      "Example: cmake -DJAVA_HOME=/usr/lib/jvm/java-11-openjdk ..."
+    )
+  endif()
+endif()
+
+# Build fbjni from source using FetchContent
+include(FetchContent)
+
+if(NOT FBJNI_VERSION)
+  set(FBJNI_VERSION 0.7.0)
+endif()
+
+FetchContent_Declare(
+  fbjni
+  GIT_REPOSITORY https://github.com/facebookincubator/fbjni.git
+  GIT_TAG v${FBJNI_VERSION}
+)
+
+# Configure fbjni build options
+set(FBJNI_BUILD_TESTS OFF CACHE BOOL "" FORCE)
+
+FetchContent_MakeAvailable(fbjni)
+
+# Use shared JNI sources from the android directory
+set(JNI_SOURCES_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../android/jni")
+
+add_library(
+  executorch_jni SHARED
+  ${JNI_SOURCES_DIR}/jni_layer.cpp
+  ${JNI_SOURCES_DIR}/log.cpp
+  ${JNI_SOURCES_DIR}/jni_layer_runtime.cpp
+  ${JNI_SOURCES_DIR}/jni_helper.cpp
+)
+
+set(link_libraries)
+list(
+  APPEND
+  link_libraries
+  executorch
+  extension_data_loader
+  extension_flat_tensor
+  extension_module
+  extension_runner_util
+  extension_tensor
+  extension_threadpool
+  fbjni
+)
+
+if(EXECUTORCH_JAVA_PROFILING)
+  list(APPEND link_libraries etdump flatccrt)
+  target_compile_definitions(
+    executorch_jni PUBLIC EXECUTORCH_JAVA_PROFILING=1
+  )
+endif()
+
+if(TARGET optimized_native_cpu_ops_lib)
+  list(APPEND link_libraries optimized_native_cpu_ops_lib)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)
+elseif(TARGET portable_ops_lib)
+  list(APPEND link_libraries portable_ops_lib portable_kernels)
+  executorch_target_link_options_shared_lib(portable_ops_lib)
+endif()
+
+if(TARGET quantized_kernels)
+  list(APPEND link_libraries quantized_kernels quantized_ops_lib)
+  executorch_target_link_options_shared_lib(quantized_ops_lib)
+endif()
+
+if(TARGET xnnpack_backend)
+  executorch_target_link_options_shared_lib(xnnpack_backend)
+  list(
+    APPEND
+    link_libraries
+    xnnpack_backend
+    XNNPACK
+    pthreadpool
+    cpuinfo
+    xnnpack-microkernels-prod
+  )
+  if(TARGET kleidiai)
+    list(APPEND link_libraries kleidiai)
+  endif()
+endif()
+
+if(EXECUTORCH_BUILD_KERNELS_LLM)
+  list(APPEND link_libraries $<LINK_LIBRARY:WHOLE_ARCHIVE,custom_ops>)
+endif()
+
+if(TARGET pthreadpool)
+  target_include_directories(
+    executorch_jni
+    PUBLIC
+    ${EXECUTORCH_ROOT}/backends/xnnpack/third-party/cpuinfo/include
+  )
+  target_include_directories(
+    executorch_jni
+    PUBLIC
+    ${EXECUTORCH_ROOT}/backends/xnnpack/third-party/pthreadpool/include
+  )
+endif()
+
+if(EXECUTORCH_JNI_CUSTOM_LIBRARY)
+  list(APPEND link_libraries ${EXECUTORCH_JNI_CUSTOM_LIBRARY})
+  target_link_libraries(
+    executorch_jni -Wl,--whole-archive ${EXECUTORCH_JNI_CUSTOM_LIBRARY}
+    -Wl,--no-whole-archive
+  )
+endif()
+
+if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
+  target_sources(executorch_jni PRIVATE ${JNI_SOURCES_DIR}/jni_layer_training.cpp)
+  list(APPEND link_libraries extension_training)
+  target_compile_definitions(
+    executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_TRAINING=1
+  )
+endif()
+
+if(EXECUTORCH_BUILD_LLAMA_JNI)
+  target_sources(executorch_jni PRIVATE ${JNI_SOURCES_DIR}/jni_layer_llama.cpp)
+  list(APPEND link_libraries extension_llm_runner)
+  target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1)
+endif()
+
+target_include_directories(
+  executorch_jni
+  PRIVATE
+  ${_common_include_directories}
+  ${JNI_INCLUDE_DIRS}
+  ${fbjni_SOURCE_DIR}/cxx
+)
+
+target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) + +target_link_libraries(executorch_jni ${link_libraries}) + +# Platform-specific logging library +if(APPLE) + # No additional logging library needed on macOS +elseif(UNIX) + # No additional logging library needed on Linux +elseif(WIN32) + # No additional logging library needed on Windows +endif() diff --git a/extension/java/build.gradle b/extension/java/build.gradle new file mode 100644 index 00000000000..22bf5340501 --- /dev/null +++ b/extension/java/build.gradle @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +plugins { + id 'java-library' + id 'maven-publish' +} + +def execuTorchVersion = System.properties['execuTorchVersion'] ?: '1.0.0-SNAPSHOT' + +group = 'org.pytorch' +version = execuTorchVersion + +java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + + withSourcesJar() + withJavadocJar() +} + +repositories { + mavenCentral() +} + +dependencies { + testImplementation 'junit:junit:4.13.2' + testImplementation 'org.assertj:assertj-core:3.27.2' +} + +// Configure where to find native libraries for JAR packaging +def nativeLibDir = file("${rootDir}/build/native") + +// Task to copy native libraries into resources for JAR packaging +task copyNativeLibs(type: Copy) { + from nativeLibDir + into "${buildDir}/resources/main/native" + // Only run if native libs exist + onlyIf { nativeLibDir.exists() } +} + +processResources.dependsOn copyNativeLibs + +jar { + manifest { + attributes( + 'Implementation-Title': 'ExecuTorch Java', + 'Implementation-Version': version, + 'Automatic-Module-Name': 'org.pytorch.executorch' + ) + } + + // Include native libraries if they exist + from("${buildDir}/resources/main/native") { + into 'native' + } +} + +test { + 
useJUnit() + + testLogging { + events "passed", "skipped", "failed" + exceptionFormat "full" + } + + // Set library path for native libraries during tests + systemProperty 'java.library.path', "${nativeLibDir}/${getOsArch()}" +} + +// Helper function to determine OS/arch directory +def getOsArch() { + def osName = System.getProperty('os.name', '').toLowerCase() + def arch = System.getProperty('os.arch', '').toLowerCase() + + def os + if (osName.contains('mac') || osName.contains('darwin')) { + os = 'darwin' + } else if (osName.contains('win')) { + os = 'windows' + } else { + os = 'linux' + } + + if (arch == 'amd64' || arch == 'x86_64') { + arch = 'x86_64' + } else if (arch == 'aarch64' || arch == 'arm64') { + arch = 'aarch64' + } + + return "${os}-${arch}" +} + +publishing { + publications { + mavenJava(MavenPublication) { + from components.java + + pom { + name = 'ExecuTorch Java' + description = 'ExecuTorch Java API for desktop platforms (Linux, macOS, Windows)' + url = 'https://github.com/pytorch/executorch' + + licenses { + license { + name = 'BSD 3-Clause' + url = 'https://github.com/pytorch/executorch/blob/main/LICENSE' + } + } + + developers { + developer { + id = 'pytorch' + name = 'PyTorch Team' + url = 'https://github.com/pytorch/executorch' + } + } + + scm { + url = 'https://github.com/pytorch/executorch.git' + connection = 'scm:git:https://github.com/pytorch/executorch' + developerConnection = 'scm:git:git@github.com:pytorch/executorch.git' + } + } + } + } + + repositories { + maven { + name = 'local' + url = "${buildDir}/repo" + } + } +} + +// Custom task to build native libraries using CMake +task buildNative(type: Exec) { + description = 'Build native JNI libraries using CMake' + workingDir rootDir + + def cmakeBuildDir = "${project.rootDir}/../../cmake-out-java" + def executorchRoot = "${project.rootDir}/../.." 
+ + commandLine 'bash', '-c', """ + mkdir -p ${cmakeBuildDir} && \\ + cd ${cmakeBuildDir} && \\ + cmake ${executorchRoot} \\ + -DCMAKE_BUILD_TYPE=Release \\ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \\ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \\ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \\ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \\ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \\ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \\ + -DEXECUTORCH_BUILD_XNNPACK=ON \\ + -DEXECUTORCH_BUILD_JAVA_JNI=ON && \\ + cmake --build . -j\$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4) && \\ + mkdir -p ${nativeLibDir}/${getOsArch()} && \\ + find ${cmakeBuildDir} -name 'libexecutorch_jni.*' -exec cp {} ${nativeLibDir}/${getOsArch()}/ \\; + """ +} + +// Convenience task to build everything +task buildAll { + description = 'Build native libraries and Java JAR' + dependsOn buildNative + dependsOn jar +} diff --git a/extension/java/build_java.sh b/extension/java/build_java.sh new file mode 100755 index 00000000000..c1841c48cdf --- /dev/null +++ b/extension/java/build_java.sh @@ -0,0 +1,246 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Build script for ExecuTorch Java package (desktop platforms) +# +# This script builds: +# 1. The native JNI library (libexecutorch_jni.so/dylib/dll) +# 2. The Java JAR file +# +# Usage: +# ./build_java.sh [options] +# +# Options: +# --cmake-only Only build native libraries with CMake +# --jar-only Only build the Java JAR (assumes native libs exist) +# --clean Clean build directories before building +# --help Show this help message + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXECUTORCH_ROOT="$(cd "${SCRIPT_DIR}/../.." 
&& pwd)" +CMAKE_BUILD_DIR="${EXECUTORCH_ROOT}/cmake-out-java" +NATIVE_LIB_DIR="${SCRIPT_DIR}/build/native" + +# Detect JAVA_HOME if not set +detect_java_home() { + if [ -n "${JAVA_HOME}" ]; then + echo "Using JAVA_HOME: ${JAVA_HOME}" + return + fi + + # Try to detect Java installation + local java_path + if command -v java &> /dev/null; then + java_path=$(command -v java) + # Follow symlinks to find actual Java installation + while [ -L "$java_path" ]; do + java_path=$(readlink "$java_path") + done + + # Go up from bin/java to the Java home + local java_bin_dir=$(dirname "$java_path") + local potential_home=$(dirname "$java_bin_dir") + + # Check if this looks like a valid JAVA_HOME + if [ -f "${potential_home}/include/jni.h" ]; then + export JAVA_HOME="${potential_home}" + echo "Detected JAVA_HOME: ${JAVA_HOME}" + return + fi + fi + + # macOS: Try /usr/libexec/java_home + if [ "$(uname -s)" = "Darwin" ] && command -v /usr/libexec/java_home &> /dev/null; then + export JAVA_HOME=$(/usr/libexec/java_home 2>/dev/null || true) + if [ -n "${JAVA_HOME}" ] && [ -f "${JAVA_HOME}/include/jni.h" ]; then + echo "Detected JAVA_HOME (macOS): ${JAVA_HOME}" + return + fi + fi + + # Linux: Common locations + for dir in /usr/lib/jvm/java-* /usr/lib/jvm/default-java /usr/java/latest; do + if [ -d "$dir" ] && [ -f "${dir}/include/jni.h" ]; then + export JAVA_HOME="$dir" + echo "Detected JAVA_HOME: ${JAVA_HOME}" + return + fi + done + + echo "Warning: JAVA_HOME not set and could not be auto-detected." + echo "Please set JAVA_HOME environment variable to your JDK installation." 
+ echo "Example: export JAVA_HOME=/usr/lib/jvm/java-11-openjdk" +} + +detect_java_home + +# Determine OS and architecture +get_os_arch() { + local os arch + + case "$(uname -s)" in + Darwin*) + os="darwin" + ;; + Linux*) + os="linux" + ;; + CYGWIN*|MINGW*|MSYS*) + os="windows" + ;; + *) + os="linux" + ;; + esac + + case "$(uname -m)" in + x86_64|amd64) + arch="x86_64" + ;; + aarch64|arm64) + arch="aarch64" + ;; + *) + arch="$(uname -m)" + ;; + esac + + echo "${os}-${arch}" +} + +OS_ARCH=$(get_os_arch) + +# Number of parallel jobs +if command -v nproc &> /dev/null; then + NPROC=$(nproc) +elif command -v sysctl &> /dev/null; then + NPROC=$(sysctl -n hw.ncpu) +else + NPROC=4 +fi + +# Parse arguments +CMAKE_ONLY=false +JAR_ONLY=false +CLEAN=false + +while [[ $# -gt 0 ]]; do + case $1 in + --cmake-only) + CMAKE_ONLY=true + shift + ;; + --jar-only) + JAR_ONLY=true + shift + ;; + --clean) + CLEAN=true + shift + ;; + --help) + head -25 "$0" | tail -15 + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Clean if requested +if [ "$CLEAN" = true ]; then + echo "Cleaning build directories..." + rm -rf "${CMAKE_BUILD_DIR}" + rm -rf "${SCRIPT_DIR}/build" +fi + +# Build native libraries with CMake +build_native() { + echo "Building native JNI libraries..." 
+ echo " OS/Arch: ${OS_ARCH}" + echo " CMake build dir: ${CMAKE_BUILD_DIR}" + echo " Parallel jobs: ${NPROC}" + + mkdir -p "${CMAKE_BUILD_DIR}" + cd "${CMAKE_BUILD_DIR}" + + # Configure and build ExecutorTorch with Java JNI extension + # Using a single CMake configuration that includes the Java extension + cmake "${EXECUTORCH_ROOT}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DEXECUTORCH_BUILD_JAVA_JNI=ON \ + "$@" + + # Build all targets including executorch_jni + cmake --build . -j"${NPROC}" + + # Copy native library to output directory + mkdir -p "${NATIVE_LIB_DIR}/${OS_ARCH}" + + case "$(uname -s)" in + Darwin*) + find "${CMAKE_BUILD_DIR}" -name "libexecutorch_jni.dylib" -exec cp {} "${NATIVE_LIB_DIR}/${OS_ARCH}/" \; + ;; + Linux*) + find "${CMAKE_BUILD_DIR}" -name "libexecutorch_jni.so" -exec cp {} "${NATIVE_LIB_DIR}/${OS_ARCH}/" \; + ;; + CYGWIN*|MINGW*|MSYS*) + find "${CMAKE_BUILD_DIR}" -name "executorch_jni.dll" -exec cp {} "${NATIVE_LIB_DIR}/${OS_ARCH}/" \; + ;; + esac + + echo "Native library built: ${NATIVE_LIB_DIR}/${OS_ARCH}/" + ls -la "${NATIVE_LIB_DIR}/${OS_ARCH}/" +} + +# Build Java JAR +build_jar() { + echo "Building Java JAR..." + cd "${SCRIPT_DIR}" + + if command -v ./gradlew &> /dev/null; then + ./gradlew build + elif command -v gradle &> /dev/null; then + gradle build + else + echo "Error: Gradle not found. Please install Gradle or use the Gradle wrapper." + exit 1 + fi + + echo "JAR built successfully!" + ls -la "${SCRIPT_DIR}/build/libs/" +} + +# Main +if [ "$JAR_ONLY" = true ]; then + build_jar +elif [ "$CMAKE_ONLY" = true ]; then + build_native +else + build_native + build_jar +fi + +echo "" +echo "Build complete!" 
+echo "" +echo "To use the library:" +echo " 1. Add ${SCRIPT_DIR}/build/libs/executorch-java-*.jar to your classpath" +echo " 2. Either:" +echo " a. Set java.library.path to ${NATIVE_LIB_DIR}/${OS_ARCH}" +echo " b. Or the JAR will auto-extract native libs at runtime" diff --git a/extension/java/settings.gradle b/extension/java/settings.gradle new file mode 100644 index 00000000000..2d700a16897 --- /dev/null +++ b/extension/java/settings.gradle @@ -0,0 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +pluginManagement { + repositories { + gradlePluginPortal() + mavenCentral() + } +} + +rootProject.name = "executorch-java" diff --git a/extension/java/src/main/java/com/facebook/jni/HybridData.java b/extension/java/src/main/java/com/facebook/jni/HybridData.java new file mode 100644 index 00000000000..b75af8bfe43 --- /dev/null +++ b/extension/java/src/main/java/com/facebook/jni/HybridData.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.facebook.jni; + +import com.facebook.jni.annotations.DoNotStrip; + +/** + * HybridData holds a C++ pointer created by fbjni. + * + *

This class is a shim for non-Android platforms. On Android, fbjni provides this class + * directly. For desktop platforms, this provides the same interface so that Java code can work + * unchanged across platforms. + * + *

The actual implementation is provided by fbjni's native code. + */ +@DoNotStrip +public class HybridData { + + @DoNotStrip private long mNativePointer; + + static { + // Ensure native library is loaded before any HybridData operations + org.pytorch.executorch.NativeLibraryLoader.loadLibrary("executorch_jni"); + } + + /** Check if the native object is still valid (not destroyed). */ + public boolean isValid() { + return mNativePointer != 0; + } + + /** + * Explicitly release the C++ object. After calling this, {@link #isValid()} will return false. + * + *

This is safe to call multiple times. + */ + public synchronized native void resetNative(); + + /** + * Release native resources when garbage collected. Users should prefer calling {@link + * #resetNative()} explicitly when the lifecycle is known. + */ + @Override + protected void finalize() throws Throwable { + resetNative(); + super.finalize(); + } +} diff --git a/extension/java/src/main/java/com/facebook/jni/annotations/DoNotStrip.java b/extension/java/src/main/java/com/facebook/jni/annotations/DoNotStrip.java new file mode 100644 index 00000000000..a7241eae025 --- /dev/null +++ b/extension/java/src/main/java/com/facebook/jni/annotations/DoNotStrip.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.facebook.jni.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation that indicates a method, field, or class should not be stripped by ProGuard or + * similar tools. This is typically used for JNI entry points. + * + *

This is a platform-independent version of the annotation. On Android, fbjni provides this + * annotation; for desktop JVMs, we provide our own. + */ +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.CONSTRUCTOR}) +@Retention(RetentionPolicy.CLASS) +public @interface DoNotStrip {} diff --git a/extension/java/src/main/java/org/pytorch/executorch/DType.java b/extension/java/src/main/java/org/pytorch/executorch/DType.java new file mode 100644 index 00000000000..3aca4871d64 --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/DType.java @@ -0,0 +1,85 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import org.pytorch.executorch.annotations.Experimental; + +/** + * Codes representing tensor data types. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public enum DType { + // NOTE: "jniCode" must be kept in sync with scalar_type.h. + // NOTE: Never serialize "jniCode", because it can change between releases. + + /** Code for dtype ScalarType::Byte */ + UINT8(0), + /** Code for dtype ScalarType::Char */ + INT8(1), + /** Code for dtype ScalarType::Short */ + INT16(2), + /** Code for dtype ScalarType::Int */ + INT32(3), + /** Code for dtype ScalarType::Long */ + INT64(4), + /** Code for dtype ScalarType::Half */ + HALF(5), + /** Code for dtype ScalarType::Float */ + FLOAT(6), + /** Code for dtype ScalarType::Double */ + DOUBLE(7), + /** Code for dtype ScalarType::ComplexHalf */ + COMPLEX_HALF(8), + /** Code for dtype ScalarType::ComplexFloat */ + COMPLEX_FLOAT(9), + /** Code for dtype ScalarType::ComplexDouble */ + COMPLEX_DOUBLE(10), + /** Code for dtype ScalarType::Bool */ + BOOL(11), + /** Code for dtype ScalarType::QInt8 */ + QINT8(12), + /** Code for dtype ScalarType::QUInt8 */ + QUINT8(13), + /** Code for dtype ScalarType::QInt32 */ + QINT32(14), + /** Code for dtype ScalarType::BFloat16 */ + BFLOAT16(15), + /** Code for dtype ScalarType::QUInt4x2 */ + QINT4X2(16), + /** Code for dtype ScalarType::QUInt2x4 */ + QINT2X4(17), + /** Code for dtype ScalarType::Bits1x8 */ + BITS1X8(18), + /** Code for dtype ScalarType::Bits2x4 */ + BITS2X4(19), + /** Code for dtype ScalarType::Bits4x2 */ + BITS4X2(20), + /** Code for dtype ScalarType::Bits8 */ + BITS8(21), + /** Code for dtype ScalarType::Bits16 */ + BITS16(22), + ; + + final int jniCode; + + DType(int jniCode) { + this.jniCode = jniCode; + } + + public static DType fromJniCode(int jniCode) { + for (DType dtype : values()) { + if (dtype.jniCode == jniCode) { + return dtype; + } + } + throw new IllegalArgumentException("No DType found for jniCode " + jniCode); + } +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/EValue.java 
b/extension/java/src/main/java/org/pytorch/executorch/EValue.java new file mode 100644 index 00000000000..ab3b77ff1fb --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/EValue.java @@ -0,0 +1,247 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.annotations.DoNotStrip; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Locale; +import org.pytorch.executorch.annotations.Experimental; + +/** + * Java representation of an ExecuTorch value, which is implemented as tagged union that can be one + * of the supported types: https://pytorch.org/docs/stable/jit.html#types . + * + *

Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}. + * + *

{@code EValue} objects are constructed with {@code EValue.from(value)}, {@code + * EValue.tupleFrom(value1, value2, ...)}, {@code EValue.listFrom(value1, value2, ...)}, or one of + * the {@code dict} methods, depending on the key type. + * + *

Data is retrieved from {@code EValue} objects with the {@code toX()} methods. Note that {@code + * str}-type EValues must be extracted with {@link #toStr()}, rather than {@link #toString()}. + * + *

{@code EValue} objects may retain references to objects passed into their constructors, and + * may return references to their internal state from {@code toX()}. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +@DoNotStrip +public class EValue { + private static final int TYPE_CODE_NONE = 0; + + private static final int TYPE_CODE_TENSOR = 1; + private static final int TYPE_CODE_STRING = 2; + private static final int TYPE_CODE_DOUBLE = 3; + private static final int TYPE_CODE_INT = 4; + private static final int TYPE_CODE_BOOL = 5; + + private String[] TYPE_NAMES = { + "None", "Tensor", "String", "Double", "Int", "Bool", + }; + + @DoNotStrip private final int mTypeCode; + @DoNotStrip private Object mData; + + @DoNotStrip + private EValue(int typeCode) { + this.mTypeCode = typeCode; + } + + @DoNotStrip + public boolean isNone() { + return TYPE_CODE_NONE == this.mTypeCode; + } + + @DoNotStrip + public boolean isTensor() { + return TYPE_CODE_TENSOR == this.mTypeCode; + } + + @DoNotStrip + public boolean isBool() { + return TYPE_CODE_BOOL == this.mTypeCode; + } + + @DoNotStrip + public boolean isInt() { + return TYPE_CODE_INT == this.mTypeCode; + } + + @DoNotStrip + public boolean isDouble() { + return TYPE_CODE_DOUBLE == this.mTypeCode; + } + + @DoNotStrip + public boolean isString() { + return TYPE_CODE_STRING == this.mTypeCode; + } + + /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */ + @DoNotStrip + public static EValue optionalNone() { + return new EValue(TYPE_CODE_NONE); + } + + /** Creates a new {@code EValue} of type {@code Tensor}. */ + @DoNotStrip + public static EValue from(Tensor tensor) { + final EValue iv = new EValue(TYPE_CODE_TENSOR); + iv.mData = tensor; + return iv; + } + + /** Creates a new {@code EValue} of type {@code bool}. */ + @DoNotStrip + public static EValue from(boolean value) { + final EValue iv = new EValue(TYPE_CODE_BOOL); + iv.mData = value; + return iv; + } + + /** Creates a new {@code EValue} of type {@code int}. 
*/ + @DoNotStrip + public static EValue from(long value) { + final EValue iv = new EValue(TYPE_CODE_INT); + iv.mData = value; + return iv; + } + + /** Creates a new {@code EValue} of type {@code double}. */ + @DoNotStrip + public static EValue from(double value) { + final EValue iv = new EValue(TYPE_CODE_DOUBLE); + iv.mData = value; + return iv; + } + + /** Creates a new {@code EValue} of type {@code str}. */ + @DoNotStrip + public static EValue from(String value) { + final EValue iv = new EValue(TYPE_CODE_STRING); + iv.mData = value; + return iv; + } + + @DoNotStrip + public Tensor toTensor() { + preconditionType(TYPE_CODE_TENSOR, mTypeCode); + return (Tensor) mData; + } + + @DoNotStrip + public boolean toBool() { + preconditionType(TYPE_CODE_BOOL, mTypeCode); + return (boolean) mData; + } + + @DoNotStrip + public long toInt() { + preconditionType(TYPE_CODE_INT, mTypeCode); + return (long) mData; + } + + @DoNotStrip + public double toDouble() { + preconditionType(TYPE_CODE_DOUBLE, mTypeCode); + return (double) mData; + } + + @DoNotStrip + public String toStr() { + preconditionType(TYPE_CODE_STRING, mTypeCode); + return (String) mData; + } + + private void preconditionType(int typeCodeExpected, int typeCode) { + if (typeCode != typeCodeExpected) { + throw new IllegalStateException( + String.format( + Locale.US, + "Expected EValue type %s, actual type %s", + getTypeName(typeCodeExpected), + getTypeName(typeCode))); + } + } + + private String getTypeName(int typeCode) { + return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown"; + } + + /** + * Serializes an {@code EValue} into a byte array. Note: This method is experimental and subject + * to change without notice. + * + * @return The serialized byte array. 
+ */ + public byte[] toByteArray() { + if (isNone()) { + return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array(); + } else if (isTensor()) { + Tensor t = toTensor(); + byte[] tByteArray = t.toByteArray(); + return ByteBuffer.allocate(1 + tByteArray.length) + .put((byte) TYPE_CODE_TENSOR) + .put(tByteArray) + .array(); + } else if (isBool()) { + return ByteBuffer.allocate(2) + .put((byte) TYPE_CODE_BOOL) + .put((byte) (toBool() ? 1 : 0)) + .array(); + } else if (isInt()) { + return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array(); + } else if (isDouble()) { + return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array(); + } else if (isString()) { + return ByteBuffer.allocate(1 + toString().length()) + .put((byte) TYPE_CODE_STRING) + .put(toString().getBytes()) + .array(); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + } + + /** + * Deserializes an {@code EValue} from a byte[]. Note: This method is experimental and subject to + * change without notice. + * + * @param bytes The byte array to deserialize from. + * @return The deserialized {@code EValue}. 
+ */ + public static EValue fromByteArray(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.wrap(bytes); + if (buffer == null) { + throw new IllegalArgumentException("buffer cannot be null"); + } + if (!buffer.hasRemaining()) { + throw new IllegalArgumentException("invalid buffer"); + } + int typeCode = buffer.get(); + switch (typeCode) { + case TYPE_CODE_NONE: + return new EValue(TYPE_CODE_NONE); + case TYPE_CODE_TENSOR: + byte[] bufferArray = buffer.array(); + return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length))); + case TYPE_CODE_STRING: + throw new IllegalArgumentException("TYPE_CODE_STRING is not supported"); + case TYPE_CODE_DOUBLE: + return from(buffer.getDouble()); + case TYPE_CODE_INT: + return from(buffer.getLong()); + case TYPE_CODE_BOOL: + return from(buffer.get() != 0); + } + throw new IllegalArgumentException("invalid type code: " + typeCode); + } +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/java/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java new file mode 100644 index 00000000000..ef793c51a14 --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.annotations.DoNotStrip; + +/** Class for entire ExecuTorch Runtime related functions. */ +public class ExecuTorchRuntime { + + static { + NativeLibraryLoader.loadLibrary("executorch_jni"); + } + + private static final ExecuTorchRuntime sInstance = new ExecuTorchRuntime(); + + private ExecuTorchRuntime() {} + + /** Get the runtime instance. */ + public static ExecuTorchRuntime getRuntime() { + return sInstance; + } + + /** Get all registered ops. 
*/ + @DoNotStrip + public static native String[] getRegisteredOps(); + + /** Get all registered backends. */ + @DoNotStrip + public static native String[] getRegisteredBackends(); +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/java/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java new file mode 100644 index 00000000000..102b96ab686 --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -0,0 +1,138 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class ExecutorchRuntimeException extends RuntimeException { + // Error code constants - keep in sync with runtime/core/error.h + // System errors + public static final int OK = 0x00; + public static final int INTERNAL = 0x01; + public static final int INVALID_STATE = 0x02; + public static final int END_OF_METHOD = 0x03; + + // Logical errors + public static final int NOT_SUPPORTED = 0x10; + public static final int NOT_IMPLEMENTED = 0x11; + public static final int INVALID_ARGUMENT = 0x12; + public static final int INVALID_TYPE = 0x13; + public static final int OPERATOR_MISSING = 0x14; + public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15; + public static final int REGISTRATION_ALREADY_REGISTERED = 0x16; + + // Resource errors + public static final int NOT_FOUND = 0x20; + public static final int MEMORY_ALLOCATION_FAILED = 0x21; + public static final int ACCESS_FAILED = 0x22; + public static final int INVALID_PROGRAM = 0x23; + public static final int INVALID_EXTERNAL_DATA = 0x24; + public static final int OUT_OF_RESOURCES = 0x25; + + // Delegate errors + public static final 
int DELEGATE_INVALID_COMPATIBILITY = 0x30; + public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31; + public static final int DELEGATE_INVALID_HANDLE = 0x32; + + private static final Map ERROR_CODE_MESSAGES; + + static { + Map map = new HashMap<>(); + + // System errors + map.put(OK, "Operation successful"); + map.put(INTERNAL, "Internal error"); + map.put(INVALID_STATE, "Invalid state"); + map.put(END_OF_METHOD, "End of method reached"); + // Logical errors + map.put(NOT_SUPPORTED, "Operation not supported"); + map.put(NOT_IMPLEMENTED, "Operation not implemented"); + map.put(INVALID_ARGUMENT, "Invalid argument"); + map.put(INVALID_TYPE, "Invalid type"); + map.put(OPERATOR_MISSING, "Operator missing"); + map.put(REGISTRATION_EXCEEDING_MAX_KERNELS, "Exceeded max kernels"); + map.put(REGISTRATION_ALREADY_REGISTERED, "Kernel already registered"); + // Resource errors + map.put(NOT_FOUND, "Resource not found"); + map.put(MEMORY_ALLOCATION_FAILED, "Memory allocation failed"); + map.put(ACCESS_FAILED, "Access failed"); + map.put(INVALID_PROGRAM, "Invalid program"); + map.put(INVALID_EXTERNAL_DATA, "Invalid external data"); + map.put(OUT_OF_RESOURCES, "Out of resources"); + // Delegate errors + map.put(DELEGATE_INVALID_COMPATIBILITY, "Delegate invalid compatibility"); + map.put(DELEGATE_MEMORY_ALLOCATION_FAILED, "Delegate memory allocation failed"); + map.put(DELEGATE_INVALID_HANDLE, "Delegate invalid handle"); + ERROR_CODE_MESSAGES = Collections.unmodifiableMap(map); + } + + static class ErrorHelper { + static String formatMessage(int errorCode, String details) { + String baseMessage = ERROR_CODE_MESSAGES.get(errorCode); + if (baseMessage == null) { + baseMessage = "Unknown error code 0x" + Integer.toHexString(errorCode); + } + + String safeDetails = details != null ? 
details : "No details provided"; + return String.format( + "[Executorch Error 0x%s] %s: %s", + Integer.toHexString(errorCode), baseMessage, safeDetails); + } + + static String getDetailedErrorLogs() { + StringBuilder sb = new StringBuilder(); + try { + String[] logEntries = Module.readLogBufferStatic(); // JNI call + if (logEntries != null && logEntries.length > 0) { + sb.append("\nDetailed logs:\n"); + for (String entry : logEntries) { + sb.append(entry).append("\n"); + } + } + } catch (Exception e) { + sb.append("Failed to retrieve detailed logs: ").append(e.getMessage()); + } + return sb.toString(); + } + } + + private final int errorCode; + + public ExecutorchRuntimeException(int errorCode, String details) { + super(ErrorHelper.formatMessage(errorCode, details)); + this.errorCode = errorCode; + } + + public int getErrorCode() { + return errorCode; + } + + public String getDetailedError() { + return ErrorHelper.getDetailedErrorLogs(); + } + + // Idiomatic Java exception for invalid arguments - extends ExecutorchRuntimeException + public static class ExecutorchInvalidArgumentException extends ExecutorchRuntimeException { + public ExecutorchInvalidArgumentException(String details) { + super(INVALID_ARGUMENT, details); + } + } + + // Factory method to create an exception of the appropriate subclass. + public static RuntimeException makeExecutorchException(int errorCode, String details) { + switch (errorCode) { + case INVALID_ARGUMENT: + return new ExecutorchInvalidArgumentException(details); + default: + return new ExecutorchRuntimeException(errorCode, details); + } + } +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/MethodMetadata.java b/extension/java/src/main/java/org/pytorch/executorch/MethodMetadata.java new file mode 100644 index 00000000000..b2dde35a2d8 --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/MethodMetadata.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +/** Helper class to access the metadata for a method from a Module */ +public class MethodMetadata { + private String mName; + + private String[] mBackends; + + MethodMetadata setName(String name) { + mName = name; + return this; + } + + /** + * @return Method name + */ + public String getName() { + return mName; + } + + MethodMetadata setBackends(String[] backends) { + mBackends = backends; + return this; + } + + /** + * @return Backends used for this method + */ + public String[] getBackends() { + return mBackends; + } +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/Module.java b/extension/java/src/main/java/org/pytorch/executorch/Module.java new file mode 100644 index 00000000000..3c4a28e5097 --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/Module.java @@ -0,0 +1,258 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.pytorch.executorch.annotations.Experimental; + +/** + * Java wrapper for ExecuTorch Module. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public class Module { + + private static final Logger LOGGER = Logger.getLogger(Module.class.getName()); + + static { + // Loads libexecutorch_jni from system path or JAR resources + NativeLibraryLoader.loadLibrary("executorch_jni"); + } + + /** Load mode for the module. Load the whole file as a buffer. */ + public static final int LOAD_MODE_FILE = 0; + + /** Load mode for the module. Use mmap to load pages into memory. */ + public static final int LOAD_MODE_MMAP = 1; + + /** Load mode for the module. Use memory locking and handle errors. */ + public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; + + /** Load mode for the module. Use memory locking and ignore errors. */ + public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; + + private final HybridData mHybridData; + + private final Map mMethodMetadata; + + @DoNotStrip + private static native HybridData initHybrid( + String moduleAbsolutePath, int loadMode, int initHybrid); + + private Module(String moduleAbsolutePath, int loadMode, int numThreads) { + ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); + + mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads); + + mMethodMetadata = populateMethodMeta(); + } + + Map populateMethodMeta() { + String[] methods = getMethods(); + Map metadata = new HashMap(); + for (int i = 0; i < methods.length; i++) { + String name = methods[i]; + metadata.put(name, new MethodMetadata().setName(name)); + } + + return metadata; + } + + /** Lock protecting the non-thread safe methods in mHybridData. */ + private Lock mLock = new ReentrantLock(); + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param loadMode load mode for the module. See constants in {@link Module}. 
+ * @return new {@link org.pytorch.executorch.Module} object which owns the model module. + */ + public static Module load(final String modelPath, int loadMode) { + return load(modelPath, loadMode, 0); + } + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @param loadMode load mode for the module. See constants in {@link Module}. + * @param numThreads the number of threads to use for inference. A value of 0 defaults to a + * hardware-specific default. + * @return new {@link org.pytorch.executorch.Module} object which owns the model module. + */ + public static Module load(final String modelPath, int loadMode, int numThreads) { + File modelFile = new File(modelPath); + if (!modelFile.canRead() || !modelFile.isFile()) { + throw new RuntimeException("Cannot load model path " + modelPath); + } + return new Module(modelPath, loadMode, numThreads); + } + + /** + * Loads a serialized ExecuTorch module from the specified path on the disk to run on CPU. + * + * @param modelPath path to file that contains the serialized ExecuTorch module. + * @return new {@link org.pytorch.executorch.Module} object which owns the model module. + */ + public static Module load(final String modelPath) { + return load(modelPath, LOAD_MODE_FILE); + } + + /** + * Runs the 'forward' method of this module with the specified arguments. + * + * @param inputs arguments for the ExecuTorch module's 'forward' method. Note: if method 'forward' + * requires inputs but no inputs are given, the function will not error out, but run 'forward' + * with sample inputs. + * @return return value from the 'forward' method. + */ + public EValue[] forward(EValue... inputs) { + return execute("forward", inputs); + } + + /** + * Runs the specified method of this module with the specified arguments. + * + * @param methodName name of the ExecuTorch method to run. 
+ * @param inputs arguments that will be passed to ExecuTorch method. + * @return return value from the method. + */ + public EValue[] execute(String methodName, EValue... inputs) { + try { + mLock.lock(); + if (!mHybridData.isValid()) { + LOGGER.log(Level.SEVERE, "Attempt to use a destroyed module"); + return new EValue[0]; + } + return executeNative(methodName, inputs); + } finally { + mLock.unlock(); + } + } + + @DoNotStrip + private native EValue[] executeNative(String methodName, EValue... inputs); + + /** + * Load a method on this module. This might help with the first time inference performance, + * because otherwise the method is loaded lazily when it's execute. Note: this function is + * synchronous, and will block until the method is loaded. Therefore, it is recommended to call + * this on a background thread. However, users need to make sure that they don't execute before + * this function returns. + * + * @return the Error code if there was an error loading the method + */ + public int loadMethod(String methodName) { + try { + mLock.lock(); + if (!mHybridData.isValid()) { + LOGGER.log(Level.SEVERE, "Attempt to use a destroyed module"); + return 0x2; // InvalidState + } + return loadMethodNative(methodName); + } finally { + mLock.unlock(); + } + } + + @DoNotStrip + private native int loadMethodNative(String methodName); + + /** + * Returns the names of the backends in a certain method. + * + * @param methodName method name to query + * @return an array of backend name + */ + @DoNotStrip + private native String[] getUsedBackends(String methodName); + + /** + * Returns the names of methods. 
+ * + * @return name of methods in this Module + */ + @DoNotStrip + public native String[] getMethods(); + + /** + * Get the corresponding @MethodMetadata for a method + * + * @param name method name + * @return @MethodMetadata for this method + */ + public MethodMetadata getMethodMetadata(String name) { + if (!mMethodMetadata.containsKey(name)) { + throw new RuntimeException("method " + name + "does not exist for this module"); + } + + MethodMetadata methodMetadata = mMethodMetadata.get(name); + if (methodMetadata != null) { + methodMetadata.setBackends(getUsedBackends(name)); + } + return methodMetadata; + } + + @DoNotStrip + private static native String[] readLogBufferStaticNative(); + + public static String[] readLogBufferStatic() { + return readLogBufferStaticNative(); + } + + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ + public String[] readLogBuffer() { + return readLogBufferNative(); + } + + @DoNotStrip + private native String[] readLogBufferNative(); + + /** + * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. + * + *

Currently for internal (minibench) use only. + * + * @return true if the etdump was successfully written, false otherwise. + */ + @Experimental + @DoNotStrip + public native boolean etdump(); + + /** + * Explicitly destroys the native Module object. Calling this method is not required, as the + * native object will be destroyed when this object is garbage-collected. However, the timing of + * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory + * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. + */ + public void destroy() { + if (mLock.tryLock()) { + try { + mHybridData.resetNative(); + } finally { + mLock.unlock(); + } + } else { + LOGGER.log( + Level.WARNING, + "Destroy was called while the module was in use. Resources will not be immediately" + + " released."); + } + } +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/NativeLibraryLoader.java b/extension/java/src/main/java/org/pytorch/executorch/NativeLibraryLoader.java new file mode 100644 index 00000000000..3f0b375d3d3 --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/NativeLibraryLoader.java @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.HashSet; +import java.util.Locale; +import java.util.Set; + +/** + * Utility class to load native libraries for ExecuTorch on desktop platforms (Linux, macOS, + * Windows). + * + *

This class handles loading native libraries either from the system library path or by + * extracting them from the JAR resources to a temporary directory. + */ +public final class NativeLibraryLoader { + + private static final Set loadedLibraries = new HashSet<>(); + private static File tempDir = null; + private static boolean initialized = false; + + private NativeLibraryLoader() {} + + /** + * Load a native library by name. + * + *

First attempts to load from the system library path using {@link System#loadLibrary}. If + * that fails, attempts to extract the library from JAR resources and load it. + * + * @param libraryName the library name without platform-specific prefix/suffix (e.g., + * "executorch_jni" not "libexecutorch_jni.so") + */ + public static synchronized void loadLibrary(String libraryName) { + if (loadedLibraries.contains(libraryName)) { + return; + } + + // First, try to load from system library path + try { + System.loadLibrary(libraryName); + loadedLibraries.add(libraryName); + return; + } catch (UnsatisfiedLinkError e) { + // Fall through to try loading from JAR + } + + // Try to load from JAR resources + String platformLibName = getPlatformLibraryName(libraryName); + String resourcePath = "/native/" + getOsArch() + "/" + platformLibName; + + try { + File libFile = extractLibraryFromResources(resourcePath, platformLibName); + if (libFile != null) { + System.load(libFile.getAbsolutePath()); + loadedLibraries.add(libraryName); + return; + } + } catch (IOException e) { + // Fall through to final error + } + + // Last resort: try system load again to get a useful error message + System.loadLibrary(libraryName); + loadedLibraries.add(libraryName); + } + + /** + * Get the platform-specific library file name. + * + * @param libraryName the base library name + * @return the platform-specific file name (e.g., "libfoo.so" on Linux) + */ + private static String getPlatformLibraryName(String libraryName) { + String osName = System.getProperty("os.name", "").toLowerCase(Locale.ROOT); + + if (osName.contains("mac") || osName.contains("darwin")) { + return "lib" + libraryName + ".dylib"; + } else if (osName.contains("win")) { + return libraryName + ".dll"; + } else { + // Default to Linux/Unix style + return "lib" + libraryName + ".so"; + } + } + + /** + * Get the OS and architecture string for resource paths. 
+ * + * @return a string like "linux-x86_64", "darwin-aarch64", or "windows-x86_64" + */ + private static String getOsArch() { + String osName = System.getProperty("os.name", "").toLowerCase(Locale.ROOT); + String arch = System.getProperty("os.arch", "").toLowerCase(Locale.ROOT); + + String os; + if (osName.contains("mac") || osName.contains("darwin")) { + os = "darwin"; + } else if (osName.contains("win")) { + os = "windows"; + } else { + os = "linux"; + } + + // Normalize architecture names + if (arch.equals("amd64") || arch.equals("x86_64")) { + arch = "x86_64"; + } else if (arch.equals("aarch64") || arch.equals("arm64")) { + arch = "aarch64"; + } + + return os + "-" + arch; + } + + /** + * Extract a library from JAR resources to a temporary file. + * + * @param resourcePath the path within the JAR + * @param fileName the file name to use in the temp directory + * @return the extracted File, or null if resource not found + * @throws IOException if extraction fails + */ + private static File extractLibraryFromResources(String resourcePath, String fileName) + throws IOException { + InputStream in = NativeLibraryLoader.class.getResourceAsStream(resourcePath); + if (in == null) { + return null; + } + + try { + if (tempDir == null) { + tempDir = createTempDirectory(); + } + + File outFile = new File(tempDir, fileName); + try (OutputStream out = new FileOutputStream(outFile)) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + out.write(buffer, 0, bytesRead); + } + } + + // Make the library executable on Unix-like systems + outFile.setExecutable(true); + + // Register for cleanup on JVM exit + outFile.deleteOnExit(); + + return outFile; + } finally { + in.close(); + } + } + + /** + * Create a temporary directory for extracted native libraries. 
+ * + * @return the temporary directory + * @throws IOException if creation fails + */ + private static File createTempDirectory() throws IOException { + File temp = File.createTempFile("executorch_native", ""); + if (!temp.delete()) { + throw new IOException("Failed to delete temp file: " + temp); + } + if (!temp.mkdir()) { + throw new IOException("Failed to create temp directory: " + temp); + } + temp.deleteOnExit(); + return temp; + } +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/Tensor.java b/extension/java/src/main/java/org/pytorch/executorch/Tensor.java new file mode 100644 index 00000000000..c7028023468 --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/Tensor.java @@ -0,0 +1,1026 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import com.facebook.jni.HybridData; +import com.facebook.jni.annotations.DoNotStrip; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; +import java.util.Arrays; +import java.util.Locale; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.pytorch.executorch.annotations.Experimental; + +/** + * Representation of an ExecuTorch Tensor. Behavior is similar to PyTorch's tensor objects. + * + *

Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, where {@code data} + * can be an array or a direct {@link Buffer} (of the proper subclass). Helper methods are provided + * to allocate buffers properly. + * + *

To access Tensor data, see {@link #dtype()}, {@link #shape()}, and various {@code getDataAs*} + * methods. + * + *

When constructing {@code Tensor} objects with {@code data} as an array, it is not specified + * whether this data is copied or retained as a reference so it is recommended not to modify it + * after constructing. {@code data} passed as a {@link Buffer} is not copied, so it can be modified + * between {@link Module} calls to avoid reallocation. Data retrieved from {@code Tensor} objects + * may be copied or may be a reference to the {@code Tensor}'s internal data buffer. {@code shape} + * is always copied. + * + *

Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +public abstract class Tensor { + private static final Logger LOGGER = Logger.getLogger(Tensor.class.getName()); + + private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null"; + private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null"; + private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null"; + private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative"; + private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER = + "Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)"; + private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = + "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; + + @DoNotStrip final long[] shape; + + private static final int BYTE_SIZE_BYTES = 1; + private static final int INT_SIZE_BYTES = 4; + private static final int LONG_SIZE_BYTES = 8; + private static final int HALF_SIZE_BYTES = 2; + private static final int FLOAT_SIZE_BYTES = 4; + private static final int DOUBLE_SIZE_BYTES = 8; + + /** + * Allocates a new direct {@link ByteBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, {@link + * Tensor#fromBlobUnsigned(ByteBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static ByteBuffer allocateByteBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); + } + + /** + * Allocates a new direct {@link IntBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(IntBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. 
+ */ + public static IntBuffer allocateIntBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asIntBuffer(); + } + + /** + * Allocates a new direct {@link FloatBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static FloatBuffer allocateFloatBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + } + + /** + * Allocates a new direct {@link LongBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(LongBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static LongBuffer allocateLongBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + } + + /** + * Allocates a new direct {@link ShortBuffer} with native byte order and specified capacity that + * can be used in {@link Tensor#fromBlob(ShortBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static ShortBuffer allocateHalfBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * HALF_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asShortBuffer(); + } + + /** + * Allocates a new direct {@link DoubleBuffer} with native byte order with specified capacity that + * can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. 
+ */ + public static DoubleBuffer allocateDoubleBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); + } + + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of + * bytes. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlobUnsigned(byte[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_uint8(byteBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of + * bytes. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(byte[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_int8(byteBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of + * ints. 
+ * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(int[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape)); + intBuffer.put(data); + return new Tensor_int32(intBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array + * of floats. + * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(float[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape)); + floatBuffer.put(data); + return new Tensor_float32(floatBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float16 with specified shape and data as array + * of IEEE-754 half-precision values encoded in {@code short}s. + * + * @param data Tensor elements encoded as 16-bit floats. + * @param shape Tensor shape + */ + public static Tensor fromBlob(short[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ShortBuffer shortBuffer = allocateHalfBuffer((int) numel(shape)); + shortBuffer.put(data); + return new Tensor_float16(shortBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of + * longs. 
+ * + * @param data Tensor elements + * @param shape Tensor shape + */ + public static Tensor fromBlob(long[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape)); + longBuffer.put(data); + return new Tensor_int64(longBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array + * of doubles. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor fromBlob(double[] data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape)); + doubleBuffer.put(data); + return new Tensor_float64(doubleBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. 
+ * @param shape Tensor shape + */ + public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_uint8(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(ByteBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int8(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. 
+ * @param shape Tensor shape + */ + public static Tensor fromBlob(IntBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int32(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(FloatBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float32(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float16 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements encoded as IEEE-754 half-precision floats. The buffer is used directly without + * copying. 
+ * @param shape Tensor shape + */ + public static Tensor fromBlob(ShortBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float16(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + * @param shape Tensor shape + */ + public static Tensor fromBlob(LongBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int64(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data. + * + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. 
+ * @param shape Tensor shape + */ + public static Tensor fromBlob(DoubleBuffer data, long[] shape) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float64(data, shape); + } + + /** + * Creates a new Tensor instance with given data-type and all elements initialized to one. + * + * @param shape Tensor shape + * @param dtype Tensor data-type + */ + public static Tensor ones(long[] shape, DType dtype) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + int numElements = (int) numel(shape); + switch (dtype) { + case UINT8: + byte[] uInt8Data = new byte[numElements]; + Arrays.fill(uInt8Data, (byte) 1); + return Tensor.fromBlobUnsigned(uInt8Data, shape); + case INT8: + byte[] int8Data = new byte[numElements]; + Arrays.fill(int8Data, (byte) 1); + return Tensor.fromBlob(int8Data, shape); + case INT32: + int[] int32Data = new int[numElements]; + Arrays.fill(int32Data, 1); + return Tensor.fromBlob(int32Data, shape); + case FLOAT: + float[] float32Data = new float[numElements]; + Arrays.fill(float32Data, 1.0f); + return Tensor.fromBlob(float32Data, shape); + case INT64: + long[] int64Data = new long[numElements]; + Arrays.fill(int64Data, 1L); + return Tensor.fromBlob(int64Data, shape); + case DOUBLE: + double[] float64Data = new double[numElements]; + Arrays.fill(float64Data, 1.0); + return Tensor.fromBlob(float64Data, shape); + default: + throw new IllegalArgumentException( + String.format("Tensor.ones() cannot be used with DType %s", dtype)); + } + } + + /** + * Creates a new Tensor instance with given data-type and all elements initialized to zero. 
+ * + * @param shape Tensor shape + * @param dtype Tensor data-type + */ + public static Tensor zeros(long[] shape, DType dtype) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + int numElements = (int) numel(shape); + switch (dtype) { + case UINT8: + byte[] uInt8Data = new byte[numElements]; + return Tensor.fromBlobUnsigned(uInt8Data, shape); + case INT8: + byte[] int8Data = new byte[numElements]; + return Tensor.fromBlob(int8Data, shape); + case INT32: + int[] int32Data = new int[numElements]; + return Tensor.fromBlob(int32Data, shape); + case FLOAT: + float[] float32Data = new float[numElements]; + return Tensor.fromBlob(float32Data, shape); + case INT64: + long[] int64Data = new long[numElements]; + return Tensor.fromBlob(int64Data, shape); + case DOUBLE: + double[] float64Data = new double[numElements]; + return Tensor.fromBlob(float64Data, shape); + default: + throw new IllegalArgumentException( + String.format("Tensor.zeros() cannot be used with DType %s", dtype)); + } + } + + @DoNotStrip private HybridData mHybridData; + + private Tensor(long[] shape) { + checkShape(shape); + this.shape = Arrays.copyOf(shape, shape.length); + } + + /** Returns the number of elements in this tensor. */ + public long numel() { + return numel(this.shape); + } + + /** Calculates the number of elements in a tensor with the specified shape. */ + public static long numel(long[] shape) { + checkShape(shape); + int result = 1; + for (long s : shape) { + result *= s; + } + return result; + } + + /** Returns the shape of this tensor. (The array is a fresh copy.) */ + public long[] shape() { + return Arrays.copyOf(shape, shape.length); + } + + /** + * @return data type of this tensor. + */ + public abstract DType dtype(); + + // Called from native + @DoNotStrip + int dtypeJniCode() { + return dtype().jniCode; + } + + /** + * @return a Java byte array that contains the tensor data. This may be a copy or reference. 
+ * @throws IllegalStateException if it is called for a non-int8 tensor. + */ + public byte[] getDataAsByteArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + } + + /** + * @return a Java short array that contains the tensor data interpreted as IEEE-754 half-precision + * bit patterns. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-float16 tensor. + */ + public short[] getDataAsShortArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as short array."); + } + + /** + * @return a Java byte array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-uint8 tensor. + */ + public byte[] getDataAsUnsignedByteArray() { + throw new IllegalStateException( + "Tensor of type " + + getClass().getSimpleName() + + " cannot return data as unsigned byte array."); + } + + /** + * @return a Java int array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int32 tensor. + */ + public int[] getDataAsIntArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as int array."); + } + + /** + * @return a Java float array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-float32 tensor. + */ + public float[] getDataAsFloatArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array."); + } + + /** + * @return a Java long array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-int64 tensor. 
+ */ + public long[] getDataAsLongArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as long array."); + } + + /** + * @return a Java double array that contains the tensor data. This may be a copy or reference. + * @throws IllegalStateException if it is called for a non-float64 tensor. + */ + public double[] getDataAsDoubleArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array."); + } + + @DoNotStrip + Buffer getRawDataBuffer() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer."); + } + + static class Tensor_uint8 extends Tensor { + private final ByteBuffer data; + + private Tensor_uint8(ByteBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.UINT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsUnsignedByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape)); + } + } + + static class Tensor_int8 extends Tensor { + private final ByteBuffer data; + + private Tensor_int8(ByteBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape)); + } + } + + static class Tensor_int32 extends Tensor { + private final IntBuffer data; + + private Tensor_int32(IntBuffer 
data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT32; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public int[] getDataAsIntArray() { + data.rewind(); + int[] arr = new int[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape)); + } + } + + static class Tensor_float32 extends Tensor { + private final FloatBuffer data; + + Tensor_float32(FloatBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public float[] getDataAsFloatArray() { + data.rewind(); + float[] arr = new float[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public DType dtype() { + return DType.FLOAT; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape)); + } + } + + static class Tensor_float16 extends Tensor { + private final ShortBuffer data; + + private Tensor_float16(ShortBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.HALF; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public short[] getDataAsShortArray() { + data.rewind(); + short[] arr = new short[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public float[] getDataAsFloatArray() { + data.rewind(); + int remaining = data.remaining(); + float[] arr = new float[remaining]; + for (int i = 0; i < remaining; i++) { + arr[i] = halfBitsToFloat(data.get()); + } + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.float16)", Arrays.toString(shape)); + } + + private static float halfBitsToFloat(short halfBits) { + int h = halfBits & 0xFFFF; + int sign = 
(h >>> 15) & 0x1; + int exp = (h >>> 10) & 0x1F; + int mant = h & 0x3FF; + + if (exp == 0) { + if (mant == 0) { + return sign == 0 ? 0.0f : -0.0f; + } + float result = mant * 5.9604645e-8f; // 2^-24 + return sign == 0 ? result : -result; + } else if (exp == 0x1F) { + if (mant == 0) { + return sign == 0 ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY; + } + int bits = (sign << 31) | 0x7f800000 | (mant << 13); + return Float.intBitsToFloat(bits); + } else { + int exp32 = exp + 112; // 127 (float bias) - 15 (half bias) + int bits = (sign << 31) | (exp32 << 23) | (mant << 13); + return Float.intBitsToFloat(bits); + } + } + } + + static class Tensor_int64 extends Tensor { + private final LongBuffer data; + + private Tensor_int64(LongBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.INT64; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public long[] getDataAsLongArray() { + data.rewind(); + long[] arr = new long[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape)); + } + } + + static class Tensor_float64 extends Tensor { + private final DoubleBuffer data; + + private Tensor_float64(DoubleBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public DType dtype() { + return DType.DOUBLE; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public double[] getDataAsDoubleArray() { + data.rewind(); + double[] arr = new double[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape)); + } + } + + static class Tensor_unsupported extends Tensor { + private final ByteBuffer data; + private final DType mDtype; + + private Tensor_unsupported(ByteBuffer data, long[] 
shape, DType dtype) { + super(shape); + this.data = data; + this.mDtype = dtype; + LOGGER.log( + Level.SEVERE, + toString() + " in Java. Please consider re-export the model with proper return type"); + } + + @Override + public DType dtype() { + return mDtype; + } + + @Override + public String toString() { + return String.format("Unsupported tensor(%s, dtype=%d)", Arrays.toString(shape), this.mDtype); + } + } + + // region checks + private static void checkArgument(boolean expression, String errorMessage, Object... args) { + if (!expression) { + throw new IllegalArgumentException(String.format(Locale.US, errorMessage, args)); + } + } + + private static void checkShape(long[] shape) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + for (int i = 0; i < shape.length; i++) { + checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE); + } + } + + private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) { + final long numel = numel(shape); + checkArgument( + numel == dataCapacity, + "Inconsistent data capacity:%d and shape number elements:%d shape:%s", + dataCapacity, + numel, + Arrays.toString(shape)); + } + + // endregion checks + + // Called from native + @DoNotStrip + private static Tensor nativeNewTensor( + ByteBuffer data, long[] shape, int dtype, HybridData hybridData) { + Tensor tensor = null; + + if (DType.FLOAT.jniCode == dtype) { + tensor = new Tensor_float32(data.asFloatBuffer(), shape); + } else if (DType.HALF.jniCode == dtype) { + tensor = new Tensor_float16(data.asShortBuffer(), shape); + } else if (DType.INT32.jniCode == dtype) { + tensor = new Tensor_int32(data.asIntBuffer(), shape); + } else if (DType.INT64.jniCode == dtype) { + tensor = new Tensor_int64(data.asLongBuffer(), shape); + } else if (DType.DOUBLE.jniCode == dtype) { + tensor = new Tensor_float64(data.asDoubleBuffer(), shape); + } else if (DType.UINT8.jniCode == dtype) { + tensor = new Tensor_uint8(data, shape); + } else if (DType.INT8.jniCode == 
dtype) { + tensor = new Tensor_int8(data, shape); + } else { + tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype)); + } + tensor.mHybridData = hybridData; + return tensor; + } + + /** + * Serializes a {@code Tensor} into a byte array. Note: This method is experimental and subject to + * change without notice. This does NOT supoprt list type. + * + * @return The serialized byte array. + */ + public byte[] toByteArray() { + int dtypeSize = 0; + byte[] tensorAsByteArray = null; + if (dtype() == DType.UINT8) { + dtypeSize = BYTE_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel()]; + Tensor_uint8 thiz = (Tensor_uint8) this; + ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray()); + } else if (dtype() == DType.INT8) { + dtypeSize = BYTE_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel()]; + Tensor_int8 thiz = (Tensor_int8) this; + ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray()); + } else if (dtype() == DType.HALF) { + dtypeSize = HALF_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_float16 thiz = (Tensor_float16) this; + ByteBuffer.wrap(tensorAsByteArray).asShortBuffer().put(thiz.getDataAsShortArray()); + } else if (dtype() == DType.INT16) { + throw new IllegalArgumentException("DType.INT16 is not supported in Java so far"); + } else if (dtype() == DType.INT32) { + dtypeSize = INT_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_int32 thiz = (Tensor_int32) this; + ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray()); + } else if (dtype() == DType.INT64) { + dtypeSize = LONG_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_int64 thiz = (Tensor_int64) this; + ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray()); + } else if (dtype() == DType.FLOAT) { + dtypeSize = FLOAT_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_float32 
thiz = (Tensor_float32) this; + ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray()); + } else if (dtype() == DType.DOUBLE) { + dtypeSize = DOUBLE_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_float64 thiz = (Tensor_float64) this; + ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray()); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + ByteBuffer byteBuffer = + ByteBuffer.allocate(1 + 1 + 4 * shape.length + dtypeSize * (int) numel()); + byteBuffer.put((byte) dtype().jniCode); + byteBuffer.put((byte) shape.length); + for (long s : shape) { + byteBuffer.putInt((int) s); + } + byteBuffer.put(tensorAsByteArray); + return byteBuffer.array(); + } + + /** + * Deserializes a {@code Tensor} from a byte[]. Note: This method is experimental and subject to + * change without notice. This does NOT supoprt list type. + * + * @param bytes The byte array to deserialize from. + * @return The deserialized {@code Tensor}. 
+ */ + public static Tensor fromByteArray(byte[] bytes) { + if (bytes == null) { + throw new IllegalArgumentException("bytes cannot be null"); + } + ByteBuffer buffer = ByteBuffer.wrap(bytes); + if (!buffer.hasRemaining()) { + throw new IllegalArgumentException("invalid buffer"); + } + byte dtype = buffer.get(); + byte shapeLength = buffer.get(); + long[] shape = new long[(int) shapeLength]; + long numel = 1; + for (int i = 0; i < shapeLength; i++) { + int dim = buffer.getInt(); + if (dim < 0) { + throw new IllegalArgumentException("invalid shape"); + } + shape[i] = dim; + numel *= dim; + } + if (dtype == DType.UINT8.jniCode) { + return new Tensor_uint8(buffer, shape); + } else if (dtype == DType.INT8.jniCode) { + return new Tensor_int8(buffer, shape); + } else if (dtype == DType.HALF.jniCode) { + return new Tensor_float16(buffer.asShortBuffer(), shape); + } else if (dtype == DType.INT32.jniCode) { + return new Tensor_int32(buffer.asIntBuffer(), shape); + } else if (dtype == DType.INT64.jniCode) { + return new Tensor_int64(buffer.asLongBuffer(), shape); + } else if (dtype == DType.FLOAT.jniCode) { + return new Tensor_float32(buffer.asFloatBuffer(), shape); + } else if (dtype == DType.DOUBLE.jniCode) { + return new Tensor_float64(buffer.asDoubleBuffer(), shape); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + } +} diff --git a/extension/java/src/main/java/org/pytorch/executorch/annotations/Experimental.java b/extension/java/src/main/java/org/pytorch/executorch/annotations/Experimental.java new file mode 100644 index 00000000000..f5f36fc56da --- /dev/null +++ b/extension/java/src/main/java/org/pytorch/executorch/annotations/Experimental.java @@ -0,0 +1,18 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package org.pytorch.executorch.annotations; + +/** + * This annotation indicates that an API is experimental and may change or be removed at any time. + * It does not provide any guarantees for API stability or backward-compatibility. + * + *

This status is not permanent, and APIs marked with this annotation will need to be either made + * more robust or removed in the future. + */ +public @interface Experimental {} diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 72bddce7b5b..5f0a528c9b4 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -138,34 +138,22 @@ bool validate_cache_params( ET_CHECK_OR_RETURN_FALSE(v_cache.dim() == 4, "v_cache must be a 4D tensor"); + // For ring buffer support, we allow start_pos >= cache_size. + // The cache will wrap around, and the SDPA implementation handles + // the causal masking appropriately. + // We only require that start_pos is non-negative and that a single + // sequence fits within the cache. ET_CHECK_OR_RETURN_FALSE( - start_pos < k_cache.size(1), - "start_pos must be less than key cache at dim 1"); + start_pos >= 0, + "start_pos must be non-negative, got: %" PRId64, + start_pos); ET_CHECK_OR_RETURN_FALSE( - start_pos < v_cache.size(1), - "start_pos must be less than value cache at dim 1"); - - ET_CHECK_OR_RETURN_FALSE( - (start_pos + seq_length) <= k_cache.size(1), - "start_post + seq_length must be less than max seq length supported by key cache." - "start pos: %" PRId64 ", seq_length: %" PRId64 - "." - "key cache size: %zd", - start_pos, + seq_length <= k_cache.size(1), + "seq_length (%" PRId64 ") must be <= cache size (%zd)", seq_length, k_cache.size(1)); - ET_CHECK_OR_RETURN_FALSE( - (start_pos + seq_length) <= v_cache.size(1), - "start_post + seq_length must be less than max seq length supported by key cache." - "start pos: %" PRId64 ", seq_length: %" PRId64 - "." 
- "value cache size: %zd", - start_pos, - seq_length, - v_cache.size(1)); - // Make sure they are in contiguous dim order ET_CHECK_OR_RETURN_FALSE( is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()), @@ -218,23 +206,72 @@ void update_cache( auto value_strides = projected_value.strides(); ::executorch::aten::StridesType value_batch_dim_stride = value_strides[0]; + // Ring buffer support: wrap start_pos if it exceeds cache size + int64_t cache_seq_len = cache.size(1); + int64_t value_seq_len = projected_value.size(1); + int64_t wrapped_start_pos = start_pos % cache_seq_len; + ::executorch::aten::SizesType num_bytes_to_copy = (projected_value.numel() / projected_value.size(0)) * projected_value.element_size(); for (int64_t batch_line = 0; batch_line < projected_value.size(0); ++batch_line) { - ::executorch::aten::SizesType cache_pos_offset = - (batch_line * cache_batch_dim_stride + - start_pos * cache_seq_dim_stride) * - cache.element_size(); - ::executorch::aten::SizesType value_pos_offset = - (batch_line * value_batch_dim_stride) * cache.element_size(); - - std::memcpy( - (uint8_t*)cache_data + cache_pos_offset, - (uint8_t*)projected_value_data + value_pos_offset, - num_bytes_to_copy); + // Check if we need to handle wrapping + if (wrapped_start_pos + value_seq_len <= cache_seq_len) { + // No wrapping needed - single contiguous copy + ::executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride + + wrapped_start_pos * cache_seq_dim_stride) * + cache.element_size(); + ::executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride) * cache.element_size(); + + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)projected_value_data + value_pos_offset, + num_bytes_to_copy); + } else { + // Ring buffer wrapping needed - copy in two parts + // Part 1: from wrapped_start_pos to end of cache + int64_t first_part_len = cache_seq_len - wrapped_start_pos; + // Part 2: from beginning of cache 
(wrapped around) + int64_t second_part_len = value_seq_len - first_part_len; + + ::executorch::aten::SizesType bytes_per_token = + (projected_value.numel() / (projected_value.size(0) * projected_value.size(1))) * + projected_value.element_size(); + + // Copy first part (wrapped_start_pos to end of cache) + if (first_part_len > 0) { + ::executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride + + wrapped_start_pos * cache_seq_dim_stride) * + cache.element_size(); + ::executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride) * cache.element_size(); + + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)projected_value_data + value_pos_offset, + first_part_len * bytes_per_token); + } + + // Copy second part (beginning of cache, wrapped) + if (second_part_len > 0) { + ::executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride) * cache.element_size(); + ::executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride + + first_part_len * value_strides[1]) * + projected_value.element_size(); + + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)projected_value_data + value_pos_offset, + second_part_len * bytes_per_token); + } + } } } @@ -403,9 +440,27 @@ Tensor& custom_sdpa_out_impl( ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor"); - const int64_t num_keys_for_causal_attention = + // For ring buffer mode: cap num_keys_for_causal_attention to KV cache size + // When start_pos + seq_len exceeds cache size, we should only attend to + // the tokens that are actually in the cache (the last cache_size tokens) + int64_t kv_cache_size = k.size(2); // KV cache sequence dimension + int64_t num_keys_for_causal_attention = attn_mask.has_value() ? 
-1 : start_pos + seq_len; + // In ring buffer mode, the effective number of keys is capped by cache size + bool ring_buffer_active = num_keys_for_causal_attention > kv_cache_size; + if (ring_buffer_active) { + ET_LOG( + Info, + "SDPA: Ring buffer active - start_pos=%" PRId64 " seq_len=%" PRId64 + " num_keys=%" PRId64 " > kv_cache_size=%" PRId64 ", capping to cache size", + start_pos, + seq_len, + num_keys_for_causal_attention, + kv_cache_size); + num_keys_for_causal_attention = kv_cache_size; + } + ET_KERNEL_CHECK( ctx, resize_tensor(output, q.sizes()) == Error::Ok, diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h index 73c5ccf707f..c50eaef5256 100644 --- a/extension/llm/custom_ops/op_sdpa_impl.h +++ b/extension/llm/custom_ops/op_sdpa_impl.h @@ -853,8 +853,18 @@ void cpu_flash_attention( // but that requires storing attention mask in float as the current // code doesnt support bool attention mask. // However, lets just fix that as well. + // + // Ring buffer support: When start_pos >= kvSize (cache size), the entire + // cache is filled with past tokens. In this case, we should NOT apply + // causal masking because all cached tokens are logically "before" the + // current query position. We skip causal masking entirely. + bool ring_buffer_full = start_pos >= kvSize; + + // For ring buffer mode when cache is full, we attend to all keys + // without causal masking (effectively is_causal = false for this case) + bool apply_causal_mask = is_causal && !ring_buffer_full; int64_t num_keys = - is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize; + apply_causal_mask ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize; int64_t m_start_pos = m + start_pos; auto j_kv = j / num_reps; fill_stub(dst_data, static_cast(0), qSplitSize * headSize); @@ -957,7 +967,8 @@ void cpu_flash_attention( take care of this case because the loop for (int64_t n = 0; n < num_keys; n += kvSplitSize) will exit before that. 
*/ - if (is_causal && m_start_pos <= n + kvSplitSize) { + // Use apply_causal_mask instead of is_causal to skip masking when ring buffer is full + if (apply_causal_mask && m_start_pos <= n + kvSplitSize) { // For this fn to work k_split_size > q_split_size for (int32_t row = 0; row < qBlockSize && (m_start_pos + row < n + (kvSplitSize - 1)); diff --git a/extension/llm/custom_ops/op_update_cache.cpp b/extension/llm/custom_ops/op_update_cache.cpp index 7ab994deb5f..4fabff2ce6e 100644 --- a/extension/llm/custom_ops/op_update_cache.cpp +++ b/extension/llm/custom_ops/op_update_cache.cpp @@ -56,19 +56,16 @@ bool validate_cache_params( indices_tensor.dim_order().data(), indices_tensor.dim()), "indices must be in contiguous dim order"); } else { + // For ring buffer support, we only check that seq_length fits in the cache + // and that start_pos is non-negative. The actual positions will wrap around. ET_CHECK_OR_RETURN_FALSE( - start_pos < quantized_cache.size(1), - "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd", - start_pos, - quantized_cache.size(1)); + start_pos >= 0, + "start_pos must be non-negative, got: %" PRId64, + start_pos); ET_CHECK_OR_RETURN_FALSE( - (start_pos + seq_length) <= quantized_cache.size(1), - "start_post + seq_length must be less than max seq length supported by cache." - "start pos: %" PRId64 ", seq_length: %" PRId64 - "." 
- "cache size: %zd", - start_pos, + seq_length <= quantized_cache.size(1), + "seq_length (%" PRId64 ") must be <= cache size (%zd)", seq_length, quantized_cache.size(1)); } @@ -187,18 +184,69 @@ Tensor& update_cache_impl( } } else { // Use the original implementation with start_pos + // Support ring buffer by wrapping positions if they exceed cache size + int64_t cache_seq_len = cache.size(1); + int64_t value_seq_len = value.size(1); + + // Wrap start_pos for ring buffer mode + int64_t wrapped_start_pos = start_pos % cache_seq_len; + for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { - executorch::aten::SizesType cache_pos_offset = - (batch_line * cache_batch_dim_stride + - start_pos * cache_seq_dim_stride) * - cache.element_size(); - executorch::aten::SizesType value_pos_offset = - (batch_line * value_batch_dim_stride) * cache.element_size(); - - std::memcpy( - (uint8_t*)cache_data + cache_pos_offset, - (uint8_t*)value_data + value_pos_offset, - num_bytes_to_copy); + // Check if we need to handle wrapping + if (wrapped_start_pos + value_seq_len <= cache_seq_len) { + // No wrapping needed - single contiguous copy + executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride + + wrapped_start_pos * cache_seq_dim_stride) * + cache.element_size(); + executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride) * cache.element_size(); + + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)value_data + value_pos_offset, + num_bytes_to_copy); + } else { + // Ring buffer wrapping needed - copy in two parts + // Part 1: from wrapped_start_pos to end of cache + int64_t first_part_len = cache_seq_len - wrapped_start_pos; + // Part 2: from beginning of cache (wrapped around) + int64_t second_part_len = value_seq_len - first_part_len; + + executorch::aten::SizesType bytes_per_token = + (value.numel() / (value.size(0) * value.size(1))) * + value.element_size(); + + // Copy first part 
(wrapped_start_pos to end of cache) + if (first_part_len > 0) { + executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride + + wrapped_start_pos * cache_seq_dim_stride) * + cache.element_size(); + executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride) * cache.element_size(); + + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)value_data + value_pos_offset, + first_part_len * bytes_per_token); + } + + // Copy second part (beginning of cache, wrapped) + if (second_part_len > 0) { + executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride) * cache.element_size(); + executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride + + first_part_len * value_strides[1]) * + value.element_size(); + + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)value_data + value_pos_offset, + second_part_len * bytes_per_token); + } + } } } diff --git a/extension/llm/runner/constants.h b/extension/llm/runner/constants.h index d7b36077757..b5dcd1ff49d 100644 --- a/extension/llm/runner/constants.h +++ b/extension/llm/runner/constants.h @@ -19,6 +19,12 @@ inline constexpr auto kVocabSize = "get_vocab_size"; inline constexpr auto kUseKVCache = "use_kv_cache"; inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache"; +// Ring buffer KV cache configuration +// When enabled, the model uses a ring buffer for KV cache allowing continuous +// generation beyond the initial context length by wrapping positions. +// The sliding window size equals max_context_len for ring buffer models. 
+inline constexpr auto kIsRingBuffer = "is_ring_buffer"; + // Multimodal method name conventions inline constexpr auto kVisionEncoderMethod = "vision_encoder"; inline constexpr auto kAudioEncoderMethod = "audio_encoder"; diff --git a/extension/llm/runner/multimodal_runner.cpp b/extension/llm/runner/multimodal_runner.cpp index 96d14e2a855..dabc86a2aac 100644 --- a/extension/llm/runner/multimodal_runner.cpp +++ b/extension/llm/runner/multimodal_runner.cpp @@ -51,6 +51,17 @@ MultimodalRunner::MultimodalRunner( stats_->gpu_total_bytes = cuda_memory_tracker_->total_bytes(); stats_->gpu_free_before_load_bytes = cuda_memory_tracker_->last_free_bytes(); #endif + + // Initialize ring buffer configuration from metadata + if (metadata_.count(kIsRingBuffer) && metadata_.at(kIsRingBuffer) > 0) { + is_ring_buffer_ = true; + // Sliding window size equals max_context_len for ring buffer models + sliding_window_size_ = metadata_.at(kMaxContextLen); + ET_LOG( + Info, + "Ring buffer KV cache enabled with sliding window size: %" PRId64, + sliding_window_size_); + } } bool MultimodalRunner::is_loaded() { @@ -176,10 +187,48 @@ Error MultimodalRunner::generate( "RSS after multimodal input processing: %f MiB (0 if unsupported)", get_rss_bytes() / 1024.0 / 1024.0); - // Resolve max_new_tokens based on config - int64_t max_context_len = - metadata_.at(kMaxContextLen) - 0; // No start_pos offset - int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_); + // Resolve max_new_tokens based on config and ring buffer mode + int64_t max_context_len = metadata_.at(kMaxContextLen); + int64_t effective_context_len; + int32_t max_new_tokens; + + if (is_ring_buffer_) { + // Ring buffer mode: positions wrap around, allowing continuous generation. + // The model's KV cache uses a ring buffer that overwrites old entries, + // so we can generate beyond the initial context length. 
+ effective_context_len = sliding_window_size_; + + // In ring buffer mode, we're not limited by pos_ since positions wrap + // around. Use the effective position within the sliding window. + int64_t effective_pos = pos_ % sliding_window_size_; + max_new_tokens = + config.resolve_max_new_tokens(effective_context_len, effective_pos); + + // Log ring buffer status + if (pos_ >= sliding_window_size_) { + ET_LOG( + Info, + "Ring buffer active: logical pos %" PRId64 + " >= window size %" PRId64 ", positions will wrap", + pos_, + sliding_window_size_); + } + } else { + // Non-ring buffer mode: original behavior with hard context limit. + effective_context_len = max_context_len - pos_; + + ET_CHECK_OR_RETURN_ERROR( + pos_ < max_context_len, + InvalidArgument, + "pos_ %" PRId64 " >= max_context_len %" PRId64 + ", context exhausted - please increase max context len or enable ring " + "buffer KV cache", + pos_, + max_context_len); + + max_new_tokens = + config.resolve_max_new_tokens(effective_context_len, 0); + } ET_LOG( Info, diff --git a/extension/llm/runner/multimodal_runner.h b/extension/llm/runner/multimodal_runner.h index b34b7b05ce7..38afddabd42 100644 --- a/extension/llm/runner/multimodal_runner.h +++ b/extension/llm/runner/multimodal_runner.h @@ -161,6 +161,17 @@ class ET_EXPERIMENTAL MultimodalRunner { // Internal state int64_t pos_; + + // Ring buffer configuration for continuous generation beyond context length. + // When is_ring_buffer_ is true, the model's KV cache uses a ring buffer + // that wraps around, allowing generation to continue indefinitely by + // overwriting old cache entries. + bool is_ring_buffer_ = false; + + // The sliding window size for ring buffer models. This is the effective + // context length that the model can attend to at any given time. + // Typically equals max_context_len for ring buffer models. 
+ int64_t sliding_window_size_ = 0; }; } // namespace llm diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index c9b57fb7391..87905387317 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -191,6 +191,18 @@ class RunnerTest : public Test { }; } + std::unordered_map createRingBufferMetadata( + int64_t max_context_len = 128) { + std::unordered_map metadata = { + {"enable_dynamic_shape", false}, + {"get_max_seq_len", max_context_len}, + {"get_max_context_len", max_context_len}, + {"use_kv_cache", true}, + {"is_ring_buffer", 1}, + }; + return metadata; + } + protected: Stats stats_; std::vector return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f}; @@ -355,4 +367,409 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) { EXPECT_TRUE(runner.is_loaded()); } +// ============================================================================ +// Ring Buffer Tests +// ============================================================================ + +// Test that ring buffer mode is enabled when metadata contains is_ring_buffer=1 +// and sliding_window_size defaults to max_context_len +TEST_F(RunnerTest, RingBufferModeEnabledFromMetadata) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up expectations for the tokenizer encode method + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); + + // Set up expectations for the text prefiller + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([&](std::vector&, int64_t&) { + return (Result(4)); + }); + + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + + std::unique_ptr stats 
= + std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + // Create a Runner with ring buffer metadata (no explicit sliding_window_size) + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + + // Use ring buffer metadata with max_context_len=64 + TextLLMRunner runner( + createRingBufferMetadata(64), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + std::move(text_token_generator), + std::move(stats)); + + // Load and generate should succeed + runner.load(); + + GenerationConfig config; + config.max_new_tokens = 5; + config.echo = false; + + Error err = runner.generate("test prompt", config, nullptr); + EXPECT_EQ(err, Error::Ok); +} + +// Test that ring buffer mode works with max_context_len as sliding window size +TEST_F(RunnerTest, RingBufferModeUsesMaxContextLenAsSlidingWindowSize) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up expectations for the tokenizer encode method + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); + + // Set up expectations for the text prefiller + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([&](std::vector&, int64_t&) { + return (Result(4)); + }); + + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + + std::unique_ptr stats = + std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + // Create a Runner with ring buffer 
metadata + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + + // max_context_len=128, sliding_window_size will also be 128 + TextLLMRunner runner( + createRingBufferMetadata(128), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + std::move(text_token_generator), + std::move(stats)); + + // Load and generate should succeed + runner.load(); + + GenerationConfig config; + config.max_new_tokens = 5; + config.echo = false; + + Error err = runner.generate("test prompt", config, nullptr); + EXPECT_EQ(err, Error::Ok); +} + +// Test that ring buffer mode rejects prompts that exceed sliding window size +TEST_F(RunnerTest, RingBufferModeRejectsPromptExceedingSlidingWindow) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up tokenizer to return a long prompt (10 tokens) + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + }); + + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + + std::unique_ptr stats = + std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + // Create a Runner with ring buffer and small max_context_len=5 + // (which also becomes the sliding_window_size) + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + + TextLLMRunner runner( + createRingBufferMetadata(5), // max_context_len=5, sliding_window_size=5 + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + 
std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + std::move(text_token_generator), + std::move(stats)); + + runner.load(); + + GenerationConfig config; + config.max_new_tokens = 5; + config.echo = false; + + // Generate should fail because prompt (10 tokens) > sliding_window_size (5) + Error err = + runner.generate("long prompt that exceeds window", config, nullptr); + EXPECT_EQ(err, Error::InvalidArgument); +} + +// Test that non-ring buffer mode (default) still works with original behavior +TEST_F(RunnerTest, NonRingBufferModeBackwardCompatibility) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up expectations for the tokenizer encode method + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); + + // Set up expectations for the text prefiller + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([&](std::vector&, int64_t&) { + return (Result(4)); + }); + + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + + std::unique_ptr stats = + std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + // Create a Runner WITHOUT ring buffer (default metadata) + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + + TextLLMRunner runner( + createDefaultMetadata(), // No is_ring_buffer + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + 
std::move(text_token_generator), + std::move(stats)); + + runner.load(); + + GenerationConfig config; + config.max_new_tokens = 5; + config.echo = false; + + // Generate should succeed with default (non-ring buffer) mode + Error err = runner.generate("test prompt", config, nullptr); + EXPECT_EQ(err, Error::Ok); +} + +// Test that non-ring buffer mode rejects prompts exceeding remaining context +TEST_F(RunnerTest, NonRingBufferModeRejectsPromptExceedingRemainingContext) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up tokenizer to return a long prompt (120 tokens) + std::vector long_prompt(120); + for (size_t i = 0; i < 120; i++) { + long_prompt[i] = i + 1; + } + + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault( + [long_prompt](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>(long_prompt); + }); + + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + + std::unique_ptr stats = + std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + // Create a Runner WITHOUT ring buffer with max_context_len=128 + // But prompt is 120 tokens, leaving only 8 tokens for generation + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + + // Metadata with small max_context_len + std::unordered_map metadata = { + {"enable_dynamic_shape", false}, + {"get_max_seq_len", 50}, + {"get_max_context_len", 50}, // Only 50 tokens allowed + {"use_kv_cache", true}, + }; + + TextLLMRunner runner( + metadata, + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + 
std::move(text_token_generator), + std::move(stats)); + + runner.load(); + + GenerationConfig config; + config.max_new_tokens = 5; + config.echo = false; + + // Generate should fail because prompt (120 tokens) >= max_context_len (50) + Error err = runner.generate("very long prompt", config, nullptr); + EXPECT_EQ(err, Error::InvalidArgument); +} + +// Test that ring buffer mode allows generation after multiple calls +// (simulating continuous conversation that wraps around) +TEST_F(RunnerTest, RingBufferModeAllowsContinuousGeneration) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up expectations for the tokenizer encode method + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); + + // Set up expectations for the text prefiller + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([&](std::vector&, int64_t&) { + return (Result(4)); + }); + + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + + std::unique_ptr stats = + std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + // Create a Runner with ring buffer and small window size + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + + // Small sliding_window_size to force wrapping + TextLLMRunner runner( + createRingBufferMetadata(32), // max_context_len=32, sliding_window_size=32 + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + std::move(text_token_generator), + std::move(stats)); + + 
runner.load(); + + GenerationConfig config; + config.max_new_tokens = 20; + config.echo = false; + + // First generation + Error err1 = runner.generate("first prompt", config, nullptr); + EXPECT_EQ(err1, Error::Ok); + + // Second generation (without reset - should continue in ring buffer mode) + // In ring buffer mode, this should succeed even if pos_ > sliding_window_size + Error err2 = runner.generate("second prompt", config, nullptr); + EXPECT_EQ(err2, Error::Ok); + + // Third generation - positions should wrap around + Error err3 = runner.generate("third prompt", config, nullptr); + EXPECT_EQ(err3, Error::Ok); +} + +// Test that reset() clears position for both ring buffer and non-ring buffer +// modes +TEST_F(RunnerTest, ResetClearsPositionInRingBufferMode) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + ON_CALL(*tokenizer, encode(_, _, _)) + .WillByDefault([&](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); + + ON_CALL(*text_prefiller, prefill(_, _)) + .WillByDefault([&](std::vector&, int64_t&) { + return (Result(4)); + }); + + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + + std::unique_ptr stats = + std::make_unique(); + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + auto module = std::make_unique(); + auto io_manager = + std::make_unique(*module); + + TextLLMRunner runner( + createRingBufferMetadata(64), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::move(module), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(io_manager), + std::move(text_token_generator), + std::move(stats)); + + runner.load(); + + 
GenerationConfig config; + config.max_new_tokens = 10; + config.echo = false; + + // Generate to advance position + Error err1 = runner.generate("test prompt", config, nullptr); + EXPECT_EQ(err1, Error::Ok); + + // Reset should clear position + runner.reset(); + + // Generate again - should start from position 0 + Error err2 = runner.generate("another prompt", config, nullptr); + EXPECT_EQ(err2, Error::Ok); +} + } // namespace diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 92dbced9560..507c73a18ab 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -47,6 +47,23 @@ TextLLMRunner::TextLLMRunner( pos_(0) { // Note: This constructor assumes that text_prefiller and text_token_generator // already have references to the Module and TextDecoderRunner they need + + // Initialize ring buffer configuration from metadata + // TODO: Remove this forced enable after testing - currently forcing ring buffer mode + // Original check: if (metadata_.count(kIsRingBuffer) && metadata_.at(kIsRingBuffer) > 0) + { + is_ring_buffer_ = true; + // Sliding window size equals max_context_len for ring buffer models + sliding_window_size_ = metadata_.at(kMaxContextLen); + ET_LOG( + Info, + "Ring buffer KV cache FORCE ENABLED with sliding window size: %" PRId64, + sliding_window_size_); + + // Configure prefiller and token generator for ring buffer mode + text_prefiller_->set_ring_buffer_config(true, sliding_window_size_); + text_token_generator_->set_ring_buffer_config(true, sliding_window_size_); + } } bool TextLLMRunner::is_loaded() const { @@ -129,24 +146,65 @@ Error TextLLMRunner::generate( std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - // Reduce max_context_len by pos_ - int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_; ET_CHECK_OR_RETURN_ERROR( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); - 
ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens < max_context_len, - InvalidArgument, - "num_prompt_tokens %d >= max_context_len %" PRId64 - ", Max seq length exceeded - please increase max seq len value in your export script", - num_prompt_tokens, - max_context_len); - // Determine max_new_tokens using the GenerationConfig's resolve method, - // then subtract pos_ for max_new_tokens. - int max_new_tokens = - config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); + int64_t max_context_len = metadata_.at(kMaxContextLen); + int64_t effective_context_len; + int max_new_tokens; + + if (is_ring_buffer_) { + // Ring buffer mode: positions wrap around, allowing continuous generation. + // The model's KV cache uses a ring buffer that overwrites old entries, + // so we can generate beyond the initial context length. + + // Check that a single prompt fits within the sliding window + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens <= sliding_window_size_, + InvalidArgument, + "num_prompt_tokens %d > sliding_window_size %" PRId64 + ", Single prompt exceeds sliding window capacity", + num_prompt_tokens, + sliding_window_size_); + + // In ring buffer mode, effective context is the sliding window size + // We can always generate up to sliding_window_size tokens at a time + effective_context_len = sliding_window_size_; + + // Calculate max_new_tokens - in ring buffer mode we're not limited by pos_ + // since positions wrap around. Use the config's max or a practical limit. + max_new_tokens = + config.resolve_max_new_tokens(effective_context_len, num_prompt_tokens); + + // Log ring buffer status + if (pos_ >= sliding_window_size_) { + ET_LOG( + Info, + "Ring buffer active: logical pos %" PRId64 + " >= window size %" PRId64 ", positions will wrap", + pos_, + sliding_window_size_); + } + } else { + // Non-ring buffer mode: original behavior with hard context limit. + // Reduce max_context_len by current position. 
+ effective_context_len = max_context_len - pos_; + + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens < effective_context_len, + InvalidArgument, + "num_prompt_tokens %d >= remaining context %" PRId64 + ", Max seq length exceeded - please increase max seq len value in " + "your export script or enable ring buffer KV cache", + num_prompt_tokens, + effective_context_len); + + // Determine max_new_tokens using the GenerationConfig's resolve method + max_new_tokens = config.resolve_max_new_tokens( + effective_context_len, num_prompt_tokens); + } ET_LOG( Info, @@ -175,6 +233,8 @@ Error TextLLMRunner::generate( stats_->first_token_ms = time_in_ms(); stats_->prompt_eval_end_ms = time_in_ms(); + // Note: pos_ is already updated by prefill via the reference parameter + // print the first token from prefill. No prev_token so use cur_token for it. auto decode_result = tokenizer_->decode(cur_token, cur_token); if (!decode_result.ok()) { diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index 9dd99d82d59..1113c00251f 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -157,6 +157,17 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { // The position in KV cache of the input, starting from 0. int64_t pos_ = 0; + + // Ring buffer configuration for continuous generation beyond context length. + // When is_ring_buffer_ is true, the model's KV cache uses a ring buffer + // that wraps around, allowing generation to continue indefinitely by + // overwriting old cache entries. + bool is_ring_buffer_ = false; + + // The sliding window size for ring buffer models. This is the effective + // context length that the model can attend to at any given time. + // Typically equals max_context_len for ring buffer models. 
+ int64_t sliding_window_size_ = 0; }; } // namespace executorch::extension::llm diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index a391cef01de..fc83dad75a4 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -76,6 +76,10 @@ ::executorch::runtime::Result TextPrefiller::prefill_chunk( // When kv cache is not used, start pos is ignored int32_t num_prompt_tokens = prompt_tokens.size(); + // Note: We pass absolute positions to the model. The model's custom ops + // (update_cache, sdpa) will handle cache wrapping internally. + // This is important because RoPE needs absolute positions for correct embeddings. + // store the token uint64_t cur_token; if (enable_parallel_prefill_ || !use_kv_cache_) { diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h index a02cd3d1bf4..dd700459419 100644 --- a/extension/llm/runner/text_prefiller.h +++ b/extension/llm/runner/text_prefiller.h @@ -26,6 +26,17 @@ class ET_EXPERIMENTAL TextPrefiller { int64_t max_seq_len = 128); virtual ~TextPrefiller() = default; + + /** + * Configure ring buffer mode for continuous generation past context length. + * @param is_ring_buffer Whether ring buffer KV cache is enabled. + * @param sliding_window_size The size of the sliding window (context length). + */ + void set_ring_buffer_config(bool is_ring_buffer, int64_t sliding_window_size) { + is_ring_buffer_ = is_ring_buffer; + sliding_window_size_ = sliding_window_size; + } + /** * Prefill an LLM Module with the given text input. * @param prompt_tokens The text prompt tokens to the LLM Module. 
Encoded by @@ -77,6 +88,10 @@ class ET_EXPERIMENTAL TextPrefiller { bool use_kv_cache_; bool enable_parallel_prefill_; int64_t max_seq_len_; + + // Ring buffer configuration + bool is_ring_buffer_ = false; + int64_t sliding_window_size_ = 0; }; } // namespace llm diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index 128de05d1d9..23d05b5bd25 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -32,6 +32,16 @@ class ET_EXPERIMENTAL TextTokenGenerator { use_kv_cache_(use_kv_cache), stats_(stats) {} + /** + * Configure ring buffer mode for continuous generation past context length. + * @param is_ring_buffer Whether ring buffer KV cache is enabled. + * @param sliding_window_size The size of the sliding window (context length). + */ + void set_ring_buffer_config(bool is_ring_buffer, int64_t sliding_window_size) { + is_ring_buffer_ = is_ring_buffer; + sliding_window_size_ = sliding_window_size; + } + void set_ignore_eos(bool ignore_eos) { ignore_eos_ = ignore_eos; } @@ -87,6 +97,10 @@ class ET_EXPERIMENTAL TextTokenGenerator { // Generate our tokens while (pos < start_pos + max_new_tokens) { + // Pass absolute position to the model - the model's custom ops will handle + // cache wrapping internally. The absolute position is needed for correct + // RoPE (Rotary Position Embedding) computation. + // Run the model auto logits_res = text_decoder_runner_->step(tokens_managed, pos); @@ -175,6 +189,10 @@ class ET_EXPERIMENTAL TextTokenGenerator { bool use_kv_cache_; bool ignore_eos_ = false; + // Ring buffer configuration + bool is_ring_buffer_ = false; + int64_t sliding_window_size_ = 0; + // state machine bool should_stop_ = false;