Skip to content

Commit

Permalink
Enable ck's grouped convolution fwd instances (#3392)
Browse files Browse the repository at this point in the history
* enable convint8 ck instances

* Fix formatting error.

* add int8 unit tests for group_conv2d_fwd

* wrong naming convention change INT8 to I8

* Delete .gitignore

---------

Co-authored-by: Brian Harrison <[email protected]>
  • Loading branch information
linsun12 and BrianHarrisonAMD authored Nov 20, 2024
1 parent 4282665 commit d16aed9
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 1 addition & 3 deletions driver/conv_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,7 @@ int ConvDriver<Tgpu, Tref>::GetandSetData()
{
out_len[0] *= miopen::deref(inputTensor).GetVectorLength();
}
miopenDataType_t y_type =
(data_type == miopenInt8 || data_type == miopenInt8x4) ? miopenInt32 : data_type;
SetTensorNd(outputTensor, out_len, inflags.GetValueStr("out_layout"), y_type);
SetTensorNd(outputTensor, out_len, inflags.GetValueStr("out_layout"), data_type);
if(inflags.GetValueStr("out_cast_type") != "-1")
{
const auto out_cast_type = DataTypeFromShortString(inflags.GetValueStr("out_cast_type"));
Expand Down
3 changes: 2 additions & 1 deletion src/include/miopen/conv/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase
bool IsInt8() const
{
return GetInDataType() == miopenInt8 && GetWeightsDataType() == miopenInt8 &&
(GetOutDataType() == miopenInt32 || GetOutDataType() == miopenFloat);
(GetOutDataType() == miopenInt32 || GetOutDataType() == miopenInt8 ||
GetOutDataType() == miopenFloat);
}
bool IsFp8() const
{
Expand Down
2 changes: 2 additions & 0 deletions src/kernels/gpu_reference_kernel/naive_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2133,11 +2133,13 @@ inline __device__ void naive_conv_wrw_ndhwc(const src_data_t* __restrict__ p_in,
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, float, double, float)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, half, double, half)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, ushort, double, ushort)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, int8_t, int32_t, int8_t)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, int8_t, int32_t, int32_t)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, int8_t, int32_t, float)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, float, double, float)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, half, double, half)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, ushort, double, ushort)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, int8_t, int32_t, int8_t)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, int8_t, int32_t, int32_t)
DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, int8_t, int32_t, float)

Expand Down
3 changes: 1 addition & 2 deletions test/gtest/group_conv2d_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,4 @@ using namespace group_conv;
DEFINE_GROUP_CONV2D_TEST(float, FP32, Forward);
DEFINE_GROUP_CONV2D_TEST(half, FP16, Forward);
DEFINE_GROUP_CONV2D_TEST(bfloat16, BFP16, Forward);
/// \todo int8_t tests don't work. Need debugging
// DEFINE_GROUP_CONV2D_TEST(int8_t, Forward);
DEFINE_GROUP_CONV2D_TEST(int8_t, I8, Forward);

0 comments on commit d16aed9

Please sign in to comment.