Fix for Host code binary operation. #1855

Open · wants to merge 3 commits into base: main
94 changes: 94 additions & 0 deletions include/cutlass/array.h
@@ -1019,6 +1019,100 @@ struct negate<Array<T, N>> {
}
};

/// Fused and-popc-add
template <typename T, int N>
struct and_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {

Array<T, N> result;
and_popc_add<T> scalar_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(a[i], b[i], c[i]);
}

return result;
}

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {

Array<T, N> result;
and_popc_add<T> scalar_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(a[i], scalar, c[i]);
}

return result;
}

CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {

Array<T, N> result;
and_popc_add<T> scalar_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(scalar, b[i], c[i]);
}

return result;
}
};

/// Fused xor-popc-add
template <typename T, int N>
struct xor_popc_add<Array<T, N>, Array<T, N>, Array<T, N>> {

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {

Array<T, N> result;
xor_popc_add<T> scalar_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(a[i], b[i], c[i]);
}

return result;
}

CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {

Array<T, N> result;
xor_popc_add<T> scalar_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(a[i], scalar, c[i]);
}

return result;
}

CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {

Array<T, N> result;
xor_popc_add<T> scalar_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
result[i] = scalar_op(scalar, b[i], c[i]);
}

return result;
}
};

/// Fused multiply-add
template <typename T, int N>
struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
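For context, here is a minimal host-side sketch (not part of the diff) of how the new array-level functor is expected to behave: it applies the scalar `and_popc_add` element-wise. The driver, element type, and array width below are illustrative assumptions.

```cpp
#include <cstdint>
#include <iostream>

#include "cutlass/array.h"
#include "cutlass/functional.h"

int main() {
  cutlass::Array<uint32_t, 4> a, b, c;
  a.fill(0xF0F0F0F0u);
  b.fill(0xFF00FF00u);
  c.fill(1u);

  // Element-wise: d[i] = popcount(a[i] & b[i]) + c[i]
  cutlass::and_popc_add<cutlass::Array<uint32_t, 4>,
                        cutlass::Array<uint32_t, 4>,
                        cutlass::Array<uint32_t, 4>> op;
  auto d = op(a, b, c);

  // 0xF0F0F0F0 & 0xFF00FF00 == 0xF000F000 -> popcount 8, plus 1 == 9
  std::cout << d[0] << std::endl;
  return 0;
}
```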
57 changes: 56 additions & 1 deletion include/cutlass/functional.h
@@ -579,7 +579,7 @@ struct guarded_multiply_add_relu0<half_t, half_t, half_t> {
}
};

/// Fused multiply-add
/// Fused and-add
template <typename T>
struct and_add {
CUTLASS_HOST_DEVICE
@@ -589,6 +589,33 @@ struct and_add {
};


/// Fused and-popc-add
template <typename A, typename B = A, typename C = A>
struct and_popc_add {
CUTLASS_HOST_DEVICE
C operator()(A const &a, B const &b, C const &c) const {
A and_result = a & b;

#if defined(__CUDA_ARCH__)
// Device path: __popc counts bits of a 32-bit value, so 64-bit operands
// are handled as two 32-bit halves.
int popc_result = __popc(static_cast<uint32_t>(and_result));

if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __popc(static_cast<uint32_t>(and_result >> 32));
}

#else
// Host path: use the compiler builtin instead of the CUDA intrinsic.
int popc_result = __builtin_popcount(static_cast<uint32_t>(and_result));

if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __builtin_popcount(static_cast<uint32_t>(and_result >> 32));
}

#endif

return C(popc_result) + c;
}
};

/// Fused xor-add
template <typename T>
struct xor_add {
@@ -598,6 +625,34 @@ struct xor_add {
}
};


/// Fused xor-popc-add
template <typename A, typename B = A, typename C = A>
struct xor_popc_add {
CUTLASS_HOST_DEVICE
C operator()(A const &a, B const &b, C const &c) const {
A xor_result = a ^ b;

#if defined(__CUDA_ARCH__)
// Device path: __popc counts bits of a 32-bit value, so 64-bit operands
// are handled as two 32-bit halves.
int popc_result = __popc(static_cast<uint32_t>(xor_result));

if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __popc(static_cast<uint32_t>(xor_result >> 32));
}

#else
// Host path: use the compiler builtin instead of the CUDA intrinsic.
int popc_result = __builtin_popcount(static_cast<uint32_t>(xor_result));

if constexpr (sizeof(A) == sizeof(uint64_t)) {
popc_result += __builtin_popcount(static_cast<uint32_t>(xor_result >> 32));
}

#endif

return C(popc_result) + c;
}
};

namespace detail {

// Whether namespace-unqualified conj(t) for t of type T is
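Likewise, a small host-only sketch (not in the diff) of the scalar functor semantics, including the 64-bit case that is split into two 32-bit popcounts; the operand values are arbitrary.

```cpp
#include <cassert>
#include <cstdint>

#include "cutlass/functional.h"

int main() {
  uint64_t a = 0xFFFF0000FFFF0000ull;
  uint64_t b = 0x0000FFFF0000FFFFull;
  uint64_t c = 5;

  // a ^ b is all ones: popcount 64, so the result is 64 + 5.
  cutlass::xor_popc_add<uint64_t> xor_op;
  assert(xor_op(a, b, c) == 69);

  // a & b is zero: popcount 0, so the result is just c.
  cutlass::and_popc_add<uint64_t> and_op;
  assert(and_op(a, b, c) == 5);

  return 0;
}
```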
8 changes: 4 additions & 4 deletions tools/util/include/cutlass/util/reference/host/gemm.h
@@ -352,7 +352,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");

compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, xor_add<ComputeType>>(
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
}

@@ -367,7 +367,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");

compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, xor_add<ComputeType>>(
ScalarType, ComputeType, xor_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
}
};
@@ -389,7 +389,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");

compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, and_add<ComputeType>>(
ScalarType, ComputeType, and_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
}

@@ -404,7 +404,7 @@ struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
"Tensors must be of rank 2");

compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
ScalarType, ComputeType, and_add<ComputeType>>(
ScalarType, ComputeType, and_popc_add<ComputeType>>(
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
}
};
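The effect of this change on the host reference is that the binary (XOR/AND) GEMM specializations now accumulate a popcount per k-step, matching the popc-based accumulation used by binary tensor core operations. A sketch of the XOR inner product follows; the free function, names, and types are illustrative, not the reference's actual template parameters.

```cpp
#include <cstdint>

#include "cutlass/functional.h"

// acc_{k+1} = acc_k + popcount(a[k] ^ b[k]); xor_popc_add<uint32_t> stands in
// for the xor_popc_add<ComputeType> functor used by the reference.
uint32_t xor_popc_dot(uint32_t const *a, uint32_t const *b, int K, uint32_t init) {
  cutlass::xor_popc_add<uint32_t> op;
  uint32_t acc = init;
  for (int k = 0; k < K; ++k) {
    acc = op(a[k], b[k], acc);
  }
  return acc;
}
```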