Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 27 additions & 63 deletions include/xsimd/arch/xsimd_avx2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1225,11 +1225,9 @@ namespace xsimd
__m256i r0 = _mm256_shuffle_epi8(self, half_mask);
__m256i r1 = _mm256_shuffle_epi8(swapped, half_mask);

// select lane by the mask index divided by 16
constexpr auto lane = batch_constant<
uint8_t, A,
00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00,
16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16> {};
// select lane by the mask index divided by 16, first lane is 0, second is 16.
constexpr auto lane_size = make_batch_constant<uint8_t, 16, A>();
constexpr auto lane = (make_iota_batch_constant<uint8_t, A>() / lane_size) * lane_size;
batch_bool<uint8_t, A> blend_mask = (mask & 0b10000u) != lane;
return _mm256_blendv_epi8(r0, r1, blend_mask);
}
Expand Down Expand Up @@ -1259,66 +1257,32 @@ namespace xsimd

namespace detail
{
// Marker value meaning "select nothing" in a byte swizzle mask.
//
// Returns 0x80 converted to T, i.e. a byte with only its most significant
// bit set.
// NOTE(review): presumably the top bit is what makes the byte shuffle write
// zero to that target position -- confirm against the shuffle intrinsic used.
template <typename T>
constexpr T swizzle_val_none()
{
    return static_cast<T>(0x80);
}

// Tells whether a swizzle entry reads from the opposite half of the register.
//
// val  : source index stored in the mask
// idx  : position of that entry inside the mask (the destination index)
// size : total number of entries; each half ("lane") holds size / 2 of them
//
// Returns true when destination and source fall in different halves.
template <typename T>
constexpr bool swizzle_val_is_cross_lane(T val, T idx, T size)
{
    const auto lane_width = size / 2;
    const bool dst_in_low_lane = idx < lane_width;
    const bool src_in_low_lane = val < lane_width;
    return dst_in_low_lane != src_in_low_lane;
}

template <typename T>
constexpr bool swizzle_val_is_defined(T val, T size)
template <bool cross_batch, typename T, T... Vals>
struct swizzle_mask
{
return (0 <= val) && (val < size);
}

// Mask entry for the "same lane" shuffle.
//
// When val is a defined index and its source sits in the same half as the
// destination idx, the result is the index relative to that half
// (val % (size / 2)). Every other entry gets the "select nothing" marker.
template <typename T>
constexpr T swizzle_self_val(T val, T idx, T size)
{
    if (!swizzle_val_is_defined(val, size) || swizzle_val_is_cross_lane(val, idx, size))
    {
        return swizzle_val_none<T>();
    }
    return val % (size / 2);
}
static constexpr auto values = std::array<T, sizeof...(Vals)> { Vals... };

// Builds the compile-time "same lane" mask type.
//
// Each mask value Vals[i] is paired with its position Ids[i] and mapped through
// swizzle_self_val, using sizeof...(Vals) as the total size. The mapped values
// are carried in the batch_constant return type; the returned object itself is
// value-initialized. The two packs are expanded together, so Ids must have as
// many elements as Vals.
template <typename T, typename A, T... Vals, std::size_t... Ids>
constexpr batch_constant<T, A, swizzle_self_val(Vals, T(Ids), static_cast<T>(sizeof...(Vals)))...>
swizzle_make_self_batch_impl(std::index_sequence<Ids...>)
{
// The result is encoded entirely in the return type.
return {};
}
static constexpr T get(std::size_t idx_, std::size_t size_) noexcept
{
const T size = static_cast<T>(size_);
const T idx = static_cast<T>(idx_);
const T val = values[idx_];

// Entry point for the "same lane" mask: generates the positions
// 0 .. sizeof...(Vals) - 1 and forwards them, together with Vals, to
// swizzle_make_self_batch_impl.
template <typename T, typename A, T... Vals>
constexpr auto swizzle_make_self_batch()
{
    using positions = std::make_index_sequence<sizeof...(Vals)>;
    return swizzle_make_self_batch_impl<T, A, Vals...>(positions());
}
// Check if value in bounds
if ((T(0) <= val) && (val < size))
{
// Whether we need to access the value from the other lane
const bool val_is_cross_lane = (idx < (size / 2)) != (val < (size / 2));
if (val_is_cross_lane == cross_batch)
{
return val % (size / 2);
}
}

// Mask entry for the "other lane" shuffle.
//
// When val is a defined index and its source sits in the opposite half from
// the destination idx, the result is the index relative to its half
// (val % (size / 2)). Every other entry gets the "select nothing" marker.
template <typename T>
constexpr T swizzle_cross_val(T val, T idx, T size)
{
    if (!swizzle_val_is_defined(val, size) || !swizzle_val_is_cross_lane(val, idx, size))
    {
        return swizzle_val_none<T>();
    }
    return val % (size / 2);
}

// Builds the compile-time "other lane" mask type.
//
// Each mask value Vals[i] is paired with its position Ids[i] and mapped through
// swizzle_cross_val, using sizeof...(Vals) as the total size. The mapped values
// are carried in the batch_constant return type; the returned object itself is
// value-initialized. The two packs are expanded together, so Ids must have as
// many elements as Vals.
template <typename T, typename A, T... Vals, std::size_t... Ids>
constexpr batch_constant<T, A, swizzle_cross_val(Vals, T(Ids), static_cast<T>(sizeof...(Vals)))...>
swizzle_make_cross_batch_impl(std::index_sequence<Ids...>)
{
// The result is encoded entirely in the return type.
return {};
}

// Entry point for the "other lane" mask: generates the positions
// 0 .. sizeof...(Vals) - 1 and forwards them, together with Vals, to
// swizzle_make_cross_batch_impl.
template <typename T, typename A, T... Vals>
constexpr auto swizzle_make_cross_batch()
{
    using positions = std::make_index_sequence<sizeof...(Vals)>;
    return swizzle_make_cross_batch_impl<T, A, Vals...>(positions());
}
// Index is out of bounds, or belongs to the other (self/cross) mask: return a value with
// its most significant bit set to 1, which will set the swizzle target to 0.
return ~T {};
}
};
}

// swizzle (constant mask)
Expand Down Expand Up @@ -1354,8 +1318,8 @@ namespace xsimd

// We can outsmart the dynamic version by creating a compile-time mask that leaves zeros
// where it does not need to select data, resulting in a simple OR merge of the two batches.
constexpr auto self_mask = detail::swizzle_make_self_batch<uint8_t, A, Vals...>();
constexpr auto cross_mask = detail::swizzle_make_cross_batch<uint8_t, A, Vals...>();
constexpr auto self_mask = make_batch_constant<uint8_t, detail::swizzle_mask<false, uint8_t, Vals...>, A>();
constexpr auto cross_mask = make_batch_constant<uint8_t, detail::swizzle_mask<true, uint8_t, Vals...>, A>();

// permute bytes within each lane (AVX2 only)
__m256i r0 = _mm256_shuffle_epi8(self, self_mask.as_batch());
Expand Down