Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
355 changes: 3 additions & 352 deletions include/exec/sequence.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024 NVIDIA Corporation
* Copyright (c) 2026 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
Expand All @@ -15,361 +15,12 @@
*/
#pragma once

#include "../stdexec/__detail/__tuple.hpp"
#include "../stdexec/__detail/__variant.hpp"
#include "../stdexec/execution.hpp"

#include "completion_signatures.hpp"

#include <type_traits>

STDEXEC_PRAGMA_PUSH()
STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
#include "../stdexec/__detail/__sequence.hpp"

namespace experimental::execution
{
namespace _seq
{
template <class... Senders>
struct _sndr;

struct sequence_t
{
template <class Sender>
STDEXEC_ATTRIBUTE(nodiscard, host, device)
constexpr auto operator()(Sender sndr) const
noexcept(STDEXEC::__nothrow_move_constructible<Sender>) -> Sender;

template <class... Senders>
requires(sizeof...(Senders) > 1)
STDEXEC_ATTRIBUTE(nodiscard, host, device)
constexpr auto operator()(Senders... sndrs) const
noexcept(STDEXEC::__nothrow_move_constructible<Senders...>) -> _sndr<Senders...>;
};

template <class Rcvr>
struct _opstate_base
{
template <class... Args>
STDEXEC_ATTRIBUTE(host, device)
constexpr void _set_value([[maybe_unused]] Args&&... args) noexcept
{
STDEXEC::set_value(static_cast<Rcvr&&>(_rcvr), static_cast<Args&&>(args)...);
}

STDEXEC_ATTRIBUTE(host, device)
constexpr void _start_next() noexcept
{
(*_start_next_)(this);
}

Rcvr _rcvr;
void (*_start_next_)(_opstate_base*) noexcept = nullptr;
};

template <class Rcvr>
struct _rcvr_base
{
using receiver_concept = STDEXEC::receiver_t;

template <class Error>
STDEXEC_ATTRIBUTE(host, device)
constexpr void set_error(Error&& err) && noexcept
{
STDEXEC::set_error(static_cast<Rcvr&&>(_opstate->_rcvr), static_cast<Error&&>(err));
}

STDEXEC_ATTRIBUTE(host, device) void set_stopped() && noexcept
{
STDEXEC::set_stopped(static_cast<Rcvr&&>(_opstate->_rcvr));
}

// TODO: use the predecessor's completion scheduler as the current scheduler here.
STDEXEC_ATTRIBUTE(nodiscard, host, device)
constexpr auto get_env() const noexcept -> STDEXEC::env_of_t<Rcvr>
{
return STDEXEC::get_env(_opstate->_rcvr);
}

_opstate_base<Rcvr>* _opstate;
};

template <class Rcvr, bool IsLast>
struct _rcvr : _rcvr_base<Rcvr>
{
using receiver_concept = STDEXEC::receiver_t;

template <class... Args>
STDEXEC_ATTRIBUTE(always_inline, host, device)
constexpr void set_value(Args&&... args) && noexcept
{
if constexpr (IsLast)
{
this->_opstate->_set_value(static_cast<Args&&>(args)...);
}
else
{
this->_opstate->_start_next();
}
}
};

template <class _Tuple>
struct __convert_tuple_fn
{
template <class... _Ts>
STDEXEC_ATTRIBUTE(host, device, always_inline)
constexpr _Tuple operator()(_Ts&&... __ts) const
noexcept(STDEXEC::__nothrow_constructible_from<_Tuple, _Ts...>)
{
return _Tuple{static_cast<_Ts&&>(__ts)...};
}
};

template <class Rcvr, class... Senders>
struct _opstate;

template <class Rcvr, class CvSender0, class... Senders>
struct _opstate<Rcvr, CvSender0, Senders...> : _opstate_base<Rcvr>
{
using operation_state_concept = STDEXEC::operation_state_t;

// We will be connecting the first sender in the opstate constructor, so we don't need to
// store it in the opstate. The use of `STDEXEC::__ignore` causes the first sender to not
// be stored.
using _senders_tuple_t = STDEXEC::__tuple<STDEXEC::__ignore, Senders...>;

template <bool IsLast>
using _rcvr_t = _seq::_rcvr<Rcvr, IsLast>;

template <class Sender, class IsLast>
using _child_opstate_t = STDEXEC::connect_result_t<Sender, _rcvr_t<IsLast::value>>;

using _mk_child_ops_variant_fn =
STDEXEC::__mzip_with2<STDEXEC::__q2<_child_opstate_t>, STDEXEC::__qq<STDEXEC::__variant>>;

using __is_last_mask_t =
STDEXEC::__mfill_c<sizeof...(Senders),
STDEXEC::__mfalse,
STDEXEC::__mbind_back_q<STDEXEC::__mlist, STDEXEC::__mtrue>>;

using _ops_variant_t = STDEXEC::__minvoke<_mk_child_ops_variant_fn,
STDEXEC::__tuple<CvSender0, Senders...>,
__is_last_mask_t>;

template <class CvSndrs>
STDEXEC_ATTRIBUTE(host, device)
constexpr explicit _opstate(Rcvr&& rcvr, CvSndrs&& sndrs)
noexcept(::STDEXEC::__nothrow_applicable<__convert_tuple_fn<_senders_tuple_t>, CvSndrs>
&& ::STDEXEC::__nothrow_connectable<::STDEXEC::__tuple_element_t<0, CvSndrs>,
_rcvr_t<sizeof...(Senders) == 0>>)
: _opstate_base<Rcvr>{static_cast<Rcvr&&>(rcvr)}
// move all but the first sender into the opstate:
, _sndrs{
STDEXEC::__apply(__convert_tuple_fn<_senders_tuple_t>{}, static_cast<CvSndrs&&>(sndrs))}
{
// Below, it looks like we are using `sndrs` after it has been moved from. This is not the
// case. `sndrs` is moved into a tuple type that has `__ignore` for the first element. The
// result is that the first sender in `sndrs` is not moved from, but the rest are.
_ops.template __emplace_from<0>(STDEXEC::connect,
STDEXEC::__get<0>(static_cast<CvSndrs&&>(sndrs)),
_rcvr_t<sizeof...(Senders) == 0>{this});
}

template <std::size_t Remaining>
static constexpr void _start_next(_opstate_base<Rcvr>* _self) noexcept
{
constexpr auto __nth = sizeof...(Senders) - Remaining;
auto* self = static_cast<_opstate*>(_self);
auto& sndr = STDEXEC::__get<__nth + 1>(self->_sndrs);
constexpr bool nothrow =
STDEXEC::__nothrow_connectable<STDEXEC::__m_at_c<__nth, Senders...>,
_rcvr_t<Remaining == 1>>;
STDEXEC_TRY
{
auto& op = self->_ops.template __emplace_from<__nth + 1>(STDEXEC::connect,
std::move(sndr),
_rcvr_t<Remaining == 1>{self});
if constexpr (Remaining > 1)
{
self->_start_next_ = &_start_next<Remaining - 1>;
}
STDEXEC::start(op);
}
STDEXEC_CATCH_ALL
{
if constexpr (nothrow)
{
STDEXEC::__std::unreachable();
}
else
{
STDEXEC::set_error(static_cast<Rcvr&&>(static_cast<_opstate*>(_self)->_rcvr),
std::current_exception());
}
}
}

STDEXEC_ATTRIBUTE(host, device)
constexpr void start() noexcept
{
if (sizeof...(Senders) != 0)
{
this->_start_next_ = &_start_next<sizeof...(Senders)>;
}
STDEXEC::start(STDEXEC::__var::__get<0>(_ops));
}

_senders_tuple_t _sndrs;
_ops_variant_t _ops{STDEXEC::__no_init};
};

template <class Sender>
concept __has_eptr_completion =
STDEXEC::sender_in<Sender>
&& exec::transform_completion_signatures(STDEXEC::get_completion_signatures<Sender>(),
exec::ignore_completion(),
exec::decay_arguments<STDEXEC::set_error_t>(),
exec::ignore_completion())
.__contains(STDEXEC::__fn_ptr_t<STDEXEC::set_error_t, std::exception_ptr>());

template <class Sender0, class... Senders>
struct _sndr<Sender0, Senders...>
{
using sender_concept = STDEXEC::sender_t;

// Even without an Env, we can sometimes still determine the completion signatures
// of the sequence sender. If any of the child senders has a
// set_error(exception_ptr) completion, then the sequence sender has a
// set_error(exception_ptr) completion. We don't have to ask if any connect call
// throws.
template <class Self, class... Env>
requires(sizeof...(Env) > 0)
|| __has_eptr_completion<STDEXEC::__copy_cvref_t<Self, Sender0>>
|| (__has_eptr_completion<Senders> || ...)
STDEXEC_ATTRIBUTE(host, device)
static consteval auto get_completion_signatures()
{
if constexpr (!STDEXEC::__decay_copyable<Self>)
{
return STDEXEC::__throw_compile_time_error<
STDEXEC::_SENDER_TYPE_IS_NOT_DECAY_COPYABLE_,
STDEXEC::_WITH_PRETTY_SENDER_<_sndr<Sender0, Senders...>>>();
}
else
{
using __env_t = STDEXEC::__mfront<Env..., STDEXEC::env<>>;
using __rcvr_t = STDEXEC::__receiver_archetype<__env_t>;
constexpr bool __is_nothrow = (STDEXEC::__nothrow_connectable<Senders, __rcvr_t> && ...);

// The completions of the sequence sender are the error and stopped completions of all the
// child senders plus the value completions of the last child sender.
return exec::concat_completion_signatures(
exec::transform_completion_signatures(
STDEXEC::get_completion_signatures<STDEXEC::__copy_cvref_t<Self, Sender0>, Env...>(),
exec::ignore_completion()),
exec::transform_completion_signatures(
STDEXEC::get_completion_signatures<Senders, Env...>(),
exec::ignore_completion())...,
STDEXEC::get_completion_signatures<STDEXEC::__mback<Senders...>, Env...>(),
STDEXEC::__eptr_completion_unless_t<STDEXEC::__mbool<__is_nothrow>>());
}
}

template <STDEXEC::__decay_copyable Self, class Rcvr>
STDEXEC_ATTRIBUTE(host, device)
constexpr STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Rcvr rcvr)
noexcept(STDEXEC::__nothrow_constructible_from<
_opstate<Rcvr, STDEXEC::__copy_cvref_t<Self, Sender0>, Senders...>,
Rcvr,
decltype((static_cast<Self&&>(self)._sndrs))>)
{
return _opstate<Rcvr, STDEXEC::__copy_cvref_t<Self, Sender0>, Senders...>{
static_cast<Rcvr&&>(rcvr),
static_cast<Self&&>(self)._sndrs};
}
STDEXEC_EXPLICIT_THIS_END(connect)

template <std::size_t Index, class Self>
STDEXEC_ATTRIBUTE(always_inline, host, device)
static constexpr auto&& static_get(Self&& self) noexcept
{
if constexpr (Index == 0)
{
return static_cast<Self&&>(self)._tag;
}
else if constexpr (Index == 1)
{
return static_cast<Self&&>(self)._ign;
}
else
{
return STDEXEC::__get<Index - 2>(static_cast<Self&&>(self)._sndrs);
}
}

template <std::size_t Index>
STDEXEC_ATTRIBUTE(always_inline, host, device)
constexpr auto&& get() && noexcept
{
return static_get<Index>(static_cast<_sndr&&>(*this));
}

template <std::size_t Index>
STDEXEC_ATTRIBUTE(always_inline, host, device)
constexpr auto&& get() & noexcept
{
return static_get<Index>(*this);
}

template <std::size_t Index>
STDEXEC_ATTRIBUTE(always_inline, host, device)
constexpr auto&& get() const & noexcept
{
return static_get<Index>(*this);
}

STDEXEC_ATTRIBUTE(no_unique_address, maybe_unused) sequence_t _tag;
STDEXEC_ATTRIBUTE(no_unique_address, maybe_unused) STDEXEC::__ _ign;
STDEXEC::__tuple<Sender0, Senders...> _sndrs;
};

template <class Sender>
STDEXEC_ATTRIBUTE(host, device)
constexpr auto sequence_t::operator()(Sender sndr) const
noexcept(STDEXEC::__nothrow_move_constructible<Sender>) -> Sender
{
return sndr;
}

template <class... Senders>
requires(sizeof...(Senders) > 1)
STDEXEC_ATTRIBUTE(host, device)
constexpr auto sequence_t::operator()(Senders... sndrs) const
noexcept(STDEXEC::__nothrow_move_constructible<Senders...>) -> _sndr<Senders...>
{
return _sndr<Senders...>{{}, {}, {static_cast<Senders&&>(sndrs)...}};
}
} // namespace _seq

using _seq::sequence_t;
using sequence_t = STDEXEC::__sequence_t;
inline constexpr sequence_t sequence{};
} // namespace experimental::execution

namespace exec = experimental::execution;

namespace std
{
template <class... Senders>
struct tuple_size<exec::_seq::_sndr<Senders...>>
: std::integral_constant<std::size_t, sizeof...(Senders) + 2>
{};

template <size_t I, class... Senders>
struct tuple_element<I, exec::_seq::_sndr<Senders...>>
{
using type = STDEXEC::__m_at_c<I, exec::sequence_t, STDEXEC::__, Senders...>;
};
} // namespace std

STDEXEC_PRAGMA_POP()
2 changes: 1 addition & 1 deletion include/nvexec/stream/repeat_n.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ namespace nv::execution::_strm
}
STDEXEC_CATCH_ALL
{
this->propagate_completion_signal(Tag{}, std::current_exception());
this->propagate_completion_signal(STDEXEC::set_error, std::current_exception());
}
}

Expand Down
Loading