Skip to content

Refactor to use thrust::reduce on any. #685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Prev Previous commit
Next Next commit
Clean up
  • Loading branch information
ZelboK committed Jul 30, 2024
commit 15358ec250d6c2f70e0b5cf6775d9b3131073f6a
11 changes: 1 addition & 10 deletions examples/fft_conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
#include <cassert>
#include <cstdio>
#include <cuda/std/ccomplex>
#include <thrust/reduce.h>
#include <thrust/functional.h>

using namespace matx;

Expand Down Expand Up @@ -73,11 +71,7 @@ using namespace matx;
*/
int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
{
using T = float;

using OutType = float;
using InType = float;
using FilterType = float;

index_t numSamples = 1;

MATX_ENTER_HANDLER();
Expand All @@ -97,9 +91,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
cudaEventCreate(&start);
cudaEventCreate(&stop);
// Create data objects
tensor_t<InType, 2> inView({batches, numSamples});
tensor_t<InType, 2> outView({batches, numSamples});
tensor_t<InType, 1> solView({numSamples});

// Create time domain buffers
auto sig_time = make_tensor<complex>({batches, signal_size});
Expand Down
7 changes: 1 addition & 6 deletions include/matx/core/operator_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ namespace matx {

template <bool ConvertType, typename Func, typename OutputOp, typename InputOp, typename BeginIter, typename EndIter>
__MATX_HOST__ __MATX_INLINE__ auto ReduceOutput(Func &&func, OutputOp &&out, InputOp &&in, BeginIter &&bi, EndIter &&ei) {

if constexpr (remove_cvref_t<decltype(out)>::Rank() <= 1 && is_tensor_view_v<OutputOp>) {
if (out.IsContiguous()) {
if constexpr(ConvertType) {
Expand Down Expand Up @@ -84,17 +83,14 @@ namespace matx {
}
}

auto collapsed = matx::lcollapse<remove_cvref_t<decltype(out)>::Rank()>(rcollapse<remove_cvref_t<decltype(in)>::Rank() -
remove_cvref_t<decltype(out)>::Rank()>(in_base));
auto collapsed = matx::lcollapse<remove_cvref_t<decltype(out)>::Rank()>(rcollapse<remove_cvref_t<decltype(in)>::Rank() - remove_cvref_t<decltype(out)>::Rank()>(in_base));
const auto &iter = matx::RandomOperatorIterator<decltype(collapsed), ConvertType>{collapsed};

return thrust::reduce(iter + *begin, iter + *end, op.Init(), op);
}

template <typename Func, typename OutputOp, typename InputOp, bool ConvertType = true>
__MATX_HOST__ __MATX_INLINE__ auto ReduceInput(Func &&func, OutputOp &&out, InputOp &&in) {
typename detail::base_type_t<InputOp> in_base = in;

if constexpr (in_base.Rank() < 2 && is_tensor_view_v<InputOp>) {
if (in_base.IsContiguous()) {
if constexpr (ConvertType) {
Expand All @@ -118,7 +114,6 @@ namespace matx {
auto collapsed = matx::lcollapse<remove_cvref_t<decltype(out)>::Rank()>(rcollapse<remove_cvref_t<decltype(in)>::Rank() -
remove_cvref_t<decltype(out)>::Rank()>(in_base));
const auto &iter = matx::RandomOperatorIterator<decltype(collapsed), ConvertType>{collapsed};

return ReduceOutput<ConvertType>(std::forward<Func>(func), std::forward<OutputOp>(out), iter, BeginOffset{iter}, EndOffset{iter});
}

Expand Down
4 changes: 1 addition & 3 deletions include/matx/operators/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#pragma once

#include <thrust/reduce.h>
#include <thrust/device_ptr.h>

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
Expand All @@ -42,7 +41,6 @@

namespace matx {


namespace detail {
template<typename OpA, int ORank>
class AnyOp : public BaseOp<AnyOp<OpA, ORank>>
Expand All @@ -69,7 +67,7 @@ namespace detail {
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const {
return tmp_out_(indices...);
};
};

template <typename Out, typename Executor>
void Exec(Out &&out, Executor) const {
Expand Down
114 changes: 55 additions & 59 deletions include/matx/operators/binary_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,79 +40,75 @@
template <typename I1, typename I2, \
typename = typename std::enable_if_t<is_matx_op<I1>() or \
is_matx_op<I2>()>> \
[[nodiscard]] __MATX_INLINE__ auto FUNCTION(I1 i1, I2 i2) \
[[nodiscard]] __MATX_INLINE__ auto FUNCTION(I1 i1, I2 i2) \
{ \
using I1Type = extract_value_type_t<I1>; \
using I2Type = extract_value_type_t<I2>; \
using I1Type = extract_value_type_t<I1>; \
using I2Type = extract_value_type_t<I2>; \
using Op = TENSOR_OP<I1Type, I2Type>; \
const typename detail::base_type<I1>::type &base1 = i1; \
const typename detail::base_type<I2>::type &base2 = i2; \
return detail::matxBinaryOp(base1, base2, Op()); \
const typename detail::base_type<I1>::type &base1 = i1; \
const typename detail::base_type<I2>::type &base2 = i2; \
return detail::matxBinaryOp(base1, base2, Op()); \
}

namespace matx
{
/**
* @brief Utility operator for multiplying scalars by a complex value
*
*
* @tparam T Complex type
* @tparam S Scalar type
* @param n Scalar value
* @param c Complex value
* @return Product result
*/
template <typename T, typename S>
__MATX_INLINE__
typename std::enable_if_t<!std::is_same_v<T, S> && std::is_arithmetic_v<S>,
cuda::std::complex<T>>
__MATX_HOST__ __MATX_DEVICE__ operator*(const cuda::std::complex<T> &c, S n)
{
return c * T(n);
}
__MATX_INLINE__
typename std::enable_if_t<!std::is_same_v<T, S> && std::is_arithmetic_v<S>,
cuda::std::complex<T>>
__MATX_HOST__ __MATX_DEVICE__ operator*(const cuda::std::complex<T> &c, S n)
{
return c * T(n);
}

/**
* @brief Utility operator for multiplying scalars by a complex value
*
*
* @tparam T Complex type
* @tparam S Scalar type
* @param n Scalar value
* @param c Complex value
* @return Product result
*/
template <typename T, typename S>
__MATX_INLINE__
typename std::enable_if_t<!std::is_same_v<T, S> && std::is_arithmetic_v<S>,
cuda::std::complex<T>>
__MATX_HOST__ __MATX_DEVICE__ operator*(S n, const cuda::std::complex<T> &c)
{
return T(n) * c;
}
__MATX_INLINE__
typename std::enable_if_t<!std::is_same_v<T, S> && std::is_arithmetic_v<S>,
cuda::std::complex<T>>
__MATX_HOST__ __MATX_DEVICE__ operator*(S n, const cuda::std::complex<T> &c)
{
return T(n) * c;
}

namespace detail
{
//

namespace detail {
template <class I1, class I2, class Op>
class matxBinaryOp : public BaseOp<matxBinaryOp<I1, I2, Op>>
class matxBinaryOp : public BaseOp<matxBinaryOp<I1,I2,Op>>
{
private:
mutable typename base_type<I1>::type in1_;
mutable typename base_type<I2>::type in2_;
typename base_type<Op>::type op_;

public:
// dummy type to signal this is a matxop
using matxop = bool;
using value_type = typename Op::value_type;
using self_type = matxBinaryOp<I1, I2, Op>;
// using type = in1_;
// using difference_type = in1_;

__MATX_INLINE__ const std::string str() const
{
private:
mutable typename base_type<I1>::type in1_;
mutable typename base_type<I2>::type in2_;
typename base_type<Op>::type op_;

public:
// dummy type to signal this is a matxop
using matxop = bool;
using value_type = typename Op::value_type;
using self_type = matxBinaryOp<I1, I2, Op>;

__MATX_INLINE__ const std::string str() const {
return op_.str(get_type_str(in1_), get_type_str(in2_));
}

__MATX_INLINE__ matxBinaryOp(I1 in1, I2 in2, Op op) : in1_(in1), in2_(in2), op_(op)
__MATX_INLINE__ matxBinaryOp(I1 in1, I2 in2, Op op) : in1_(in1), in2_(in2), op_(op)
{
if constexpr (Rank() > 0)
{
Expand All @@ -132,9 +128,11 @@ namespace matx
template <typename ArrayType, std::enable_if_t<is_std_array_v<ArrayType>, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const ArrayType &idx) const noexcept
{
return cuda::std::apply([&](auto &&...args)
{ return this->operator()(args...); }, idx);
}
return cuda::std::apply([&](auto &&...args) {
return this->operator()(args...);
}, idx);
}


static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
Expand All @@ -145,39 +143,36 @@ namespace matx
{
index_t size1 = detail::get_expanded_size<Rank()>(in1_, dim);
index_t size2 = detail::get_expanded_size<Rank()>(in2_, dim);
return detail::matx_max(size1, size2);
return detail::matx_max(size1,size2);
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PreRun(ShapeType &&shape, Executor &&ex) const noexcept
{
if constexpr (is_matx_op<I1>())
{
if constexpr (is_matx_op<I1>()) {
in1_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<I2>())
{
if constexpr (is_matx_op<I2>()) {
in2_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept
__MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept
{
if constexpr (is_matx_op<I1>())
{
if constexpr (is_matx_op<I1>()) {
in1_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}

if constexpr (is_matx_op<I2>())
{
if constexpr (is_matx_op<I2>()) {
in2_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}
}
};
}


#ifdef DOXYGEN_ONLY
/**
* Add two operators or tensors
Expand Down Expand Up @@ -242,6 +237,7 @@ namespace matx
*/
Op fmod(Op t, Op t2) {}


/**
* Compute the t^t2 of two operators or tensors
* @param t
Expand Down Expand Up @@ -357,7 +353,7 @@ namespace matx
* @param t2
* RHS tensor or operator input
*/
Op operator|(Op t, Op t2) {}
Op operator|(Op t, Op t2) {}

/**
* Compute t ^ t2 (bitwise XOR) of two operators or tensors
Expand All @@ -366,7 +362,7 @@ namespace matx
* @param t2
* RHS tensor or operator input
*/
Op operator^(Op t, Op t2) {}
Op operator^(Op t, Op t2) {}

/**
* Compute the arctangent of two inputs
Expand All @@ -375,7 +371,7 @@ namespace matx
* @param t2
* Y value of input
*/
Op atan2(Op t, Op t2) {}
Op atan2(Op t, Op t2) {}
#else
DEFINE_BINARY_OP(operator+, detail::AddOp);
DEFINE_BINARY_OP(operator-, detail::SubOp);
Expand Down