Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic "truncate_float" class for bf16 and fp16 quantization #3591

Open
wants to merge 59 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
c51c1ce
first pass at integrating generic float
richagadgil Oct 10, 2024
134b408
fix namespaces
richagadgil Oct 10, 2024
d4fa6eb
fix mantissa
richagadgil Oct 10, 2024
0b60841
refactor
richagadgil Oct 11, 2024
7a646f1
refactor
richagadgil Oct 11, 2024
ebe819b
add fp
richagadgil Oct 11, 2024
379a77a
fixed generic float class
richagadgil Oct 14, 2024
174384c
add fp32 test
richagadgil Oct 14, 2024
787b651
remove import
richagadgil Oct 14, 2024
1d1fa1c
update tests
richagadgil Oct 15, 2024
1791092
fp16 tests that work
richagadgil Oct 17, 2024
a2eb005
update tests
richagadgil Oct 18, 2024
ff8ffc7
updated fp16 and fp32 tests
richagadgil Oct 18, 2024
e36fd65
half tests
richagadgil Oct 22, 2024
9ac4e2a
underflow and overflow tests
richagadgil Oct 22, 2024
f05fd31
generate map
richagadgil Oct 22, 2024
cb4d92d
add more tests
richagadgil Oct 22, 2024
0cc1946
fix names
richagadgil Oct 22, 2024
85a761b
update tests
richagadgil Oct 23, 2024
65cf9ae
remove and
richagadgil Oct 24, 2024
fbabf54
disable warning
richagadgil Oct 24, 2024
549f5e6
fix tidy warning
richagadgil Oct 24, 2024
d302e5d
migraphx py fix
richagadgil Oct 25, 2024
8d475e3
add increments
richagadgil Oct 25, 2024
a0fd055
fix warnings
richagadgil Oct 25, 2024
41379fe
disable duplicate branch warning
richagadgil Oct 25, 2024
0c29c7b
add countzero_std
richagadgil Oct 28, 2024
4b012a8
ci error
richagadgil Oct 28, 2024
dbaa3a8
simplify countl
richagadgil Oct 28, 2024
b2bd2a0
fix ci
richagadgil Oct 28, 2024
6f328f0
src
richagadgil Oct 29, 2024
e6d9763
remove flag
richagadgil Oct 29, 2024
6538050
hide abi warning
richagadgil Oct 29, 2024
4e96d4d
revert changes
richagadgil Oct 29, 2024
ef11f1f
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
e4a25bd
change half in tests
richagadgil Oct 29, 2024
3354c6e
Update generic_float.hpp
richagadgil Oct 29, 2024
6de079b
format
richagadgil Oct 29, 2024
7750874
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
801f485
Merge branch 'develop' into generic_float
causten Oct 30, 2024
33e2c8d
fix bug
richagadgil Oct 30, 2024
9bb7198
Merge branch 'generic_float' of github.com:ROCm/AMDMIGraphX into gene…
richagadgil Oct 30, 2024
b3c345d
fix err
richagadgil Oct 30, 2024
03df6f9
edits
richagadgil Oct 31, 2024
ad817b2
tidy and format
richagadgil Oct 31, 2024
898417b
tidy etc
richagadgil Oct 31, 2024
aa5b9c9
gf
richagadgil Oct 31, 2024
6f72370
fix tidy errs
richagadgil Nov 1, 2024
0aab1a0
bf16 changes
richagadgil Nov 4, 2024
7b965c0
add flag to trace quantization passes (#3571)
shivadbhavsar Oct 30, 2024
5f5f13d
bf16
richagadgil Oct 30, 2024
d64b124
Update bf16.cpp
richagadgil Nov 1, 2024
a064eaa
Update bf16.hpp
richagadgil Nov 2, 2024
befbd9e
Update bf16.hpp
richagadgil Nov 2, 2024
08b9511
update files with working version
richagadgil Nov 4, 2024
12cafed
generic class for quant
richagadgil Nov 5, 2024
f604146
format
richagadgil Nov 5, 2024
edc7ccb
Merge branch 'develop' into generic_quant_class
richagadgil Nov 8, 2024
ffec081
Update quantization.cpp
richagadgil Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ add_library(migraphx
propagate_constant.cpp
promote_literals.cpp
quantization.cpp
quantize_fp16.cpp
quantize_int4.cpp
quantize_8bits.cpp
reduce_dims.cpp
Expand All @@ -115,6 +114,7 @@ add_library(migraphx
split_single_dyn_dim.cpp
target.cpp
tmp_dir.cpp
truncate_float.cpp
value.cpp
verify_args.cpp
)
Expand Down
33 changes: 33 additions & 0 deletions src/api/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,16 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
migraphx::quantize_fp16(prog, names);
}

// Quantizes `prog` to bf16, passing `names` through as the instruction-name list.
// When `names` is empty it is filled (in the caller's vector) with the single
// entry "all" before the call.
void quantize_bf16_with_op_names(program& prog, std::vector<std::string>& names)
{
    if(names.empty())
        names.emplace_back("all");
    migraphx::quantize_bf16(prog, names);
}

struct quantize_int8_options
{
std::vector<parameter_map> calibration = {};
Expand Down Expand Up @@ -2199,6 +2209,29 @@ extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
return api_error_result;
}

// C entry point for bf16 quantization with an explicit operator-name list.
// A null `prog` or `name` handle is reported through MIGRAPHX_THROW with
// migraphx_status_bad_param; the status produced by migraphx::try_ is returned as-is.
extern "C" migraphx_status migraphx_quantize_bf16_with_op_names(migraphx_program_t prog,
                                                                migraphx_quantize_op_names_t name)
{
    return migraphx::try_([&] {
        if(prog == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
        }
        if(name == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
        }
        migraphx::quantize_bf16_with_op_names(prog->object, name->object);
    });
}

// C entry point for bf16 quantization without an operator-name list.
// A null `prog` handle is reported through MIGRAPHX_THROW with
// migraphx_status_bad_param; the status produced by migraphx::try_ is returned as-is.
extern "C" migraphx_status migraphx_quantize_bf16(migraphx_program_t prog)
{
    return migraphx::try_([&] {
        if(prog == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
        }
        migraphx::quantize_bf16(prog->object);
    });
}

extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{
Expand Down
6 changes: 6 additions & 0 deletions src/api/include/migraphx/migraphx.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(bool_type, bool) \
m(half_type, half) \
m(bf16_type, bf16) \
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
Expand Down Expand Up @@ -602,6 +603,11 @@ migraphx_quantize_fp16_with_op_names(migraphx_program_t prog, migraphx_quantize_

MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);

/* Quantize `prog` to bf16 using the operator-name list held by `name`. */
MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_bf16_with_op_names(migraphx_program_t prog, migraphx_quantize_op_names_t name);

/* Quantize `prog` to bf16 without supplying an operator-name list. */
MIGRAPHX_C_EXPORT migraphx_status migraphx_quantize_bf16(migraphx_program_t prog);

MIGRAPHX_C_EXPORT migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);

Expand Down
12 changes: 12 additions & 0 deletions src/api/include/migraphx/migraphx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1484,6 +1484,18 @@ inline void quantize_fp16(const program& prog)
call(&migraphx_quantize_fp16, prog.get_handle_ptr());
}

/// Quantize program to use bf16, forwarding `names` as the operator-name list
inline void quantize_bf16(const program& prog, const quantize_op_names& names)
{
    auto* prog_handle  = prog.get_handle_ptr();
    auto* names_handle = names.get_handle_ptr();
    call(&migraphx_quantize_bf16_with_op_names, prog_handle, names_handle);
}

/// Quantize program to use bf16 (overload that takes no operator-name list)
inline void quantize_bf16(const program& prog)
{
    auto* prog_handle = prog.get_handle_ptr();
    call(&migraphx_quantize_bf16, prog_handle);
}

/// Options to be passed when quantizing for int8
struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
Expand Down
8 changes: 8 additions & 0 deletions src/api/migraphx.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,14 @@ def quantize_op_names(h):
api.add_function('migraphx_quantize_fp16',
api.params(prog='migraphx::program&'),
fname='migraphx::quantize_fp16')
# Register the C API function for bf16 quantization that takes an explicit list
# of operator names; it is bound to migraphx::quantize_bf16_with_op_names.
api.add_function('migraphx_quantize_bf16_with_op_names',
api.params(prog='migraphx::program&',
name='std::vector<std::string>&'),
fname='migraphx::quantize_bf16_with_op_names')

# Register the C API function for bf16 quantization that takes only the program;
# it is bound to migraphx::quantize_bf16.
api.add_function('migraphx_quantize_bf16',
api.params(prog='migraphx::program&'),
fname='migraphx::quantize_bf16')


@auto_handle()
Expand Down
6 changes: 6 additions & 0 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ struct compiler
compiler_target ct;
compile_options co;
bool to_fp16 = false;
bool to_bf16 = false;
bool to_fp8 = false;
bool to_int8 = false;
bool to_int4 = false;
Expand All @@ -506,6 +507,7 @@ struct compiler
ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true));
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_bf16, {"--bf16"}, ap.help("Quantize for bf16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
ap(to_int4, {"--int4-weights"}, ap.help("Quantize weights for int4"), ap.set_value(true));
Expand Down Expand Up @@ -555,6 +557,10 @@ struct compiler
{
quantize_fp16(p);
}
if(to_bf16)
{
quantize_bf16(p);
}
if(to_int8)
{
quantize_int8(p, t, {host_params(p)});
Expand Down
105 changes: 105 additions & 0 deletions src/include/migraphx/bf16.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#ifndef MIGRAPHX_GUARD_RTGLIB_BF16_HPP
#define MIGRAPHX_GUARD_RTGLIB_BF16_HPP

#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// bfloat16 type expressed through generic_float. The bfloat16 format has
// 7 explicit mantissa bits and 8 exponent bits; presumably the template
// arguments are <mantissa bits, exponent bits> — confirm against generic_float.
using bf16 = migraphx::generic_float<7, 8>;

// NOTE(review): disabled leftover below; `detail::deduce` is not declared in
// this header. Remove once it is confirmed to be unneeded.
// template <class T>
// using deduce = typename detail::deduce<T>::type;

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

namespace std {

// Generic promotion rule: bf16 combined with any other type T resolves to
// whatever float and T resolve to. Both argument orders are provided.
template <typename T>
struct common_type<migraphx::bf16, T> : std::common_type<float, T> // NOLINT
{
};

template <typename T>
struct common_type<T, migraphx::bf16> : std::common_type<float, T> // NOLINT
{
};

// bf16 paired with itself stays bf16. Both partial specializations above match
// this pair equally well, so a full specialization is needed to pick a result.
template <>
struct common_type<migraphx::bf16, migraphx::bf16>
{
    using type = migraphx::bf16;
};

// bf16 combined with generic_float<10, 5> resolves to float.
template <>
struct common_type<migraphx::bf16, migraphx::generic_float<10, 5>>
{
    using type = float;
};

// bf16 combined with any of the fp8 types resolves to float. Each pairing is
// written out for both argument orders.
// NOTE(review): presumably these full specializations avoid clashes with
// partial specializations declared for the fp8 types elsewhere — confirm.
template <>
struct common_type<migraphx::bf16, migraphx::fp8::fp8e4m3fnuz>
{
    using type = float;
};

template <>
struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::bf16>
{
    using type = float;
};

template <>
struct common_type<migraphx::bf16, migraphx::fp8::fp8e4m3fn>
{
    using type = float;
};

template <>
struct common_type<migraphx::fp8::fp8e4m3fn, migraphx::bf16>
{
    using type = float;
};

template <>
struct common_type<migraphx::bf16, migraphx::fp8::fp8e5m2>
{
    using type = float;
};

template <>
struct common_type<migraphx::fp8::fp8e5m2, migraphx::bf16>
{
    using type = float;
};

} // namespace std

#endif
6 changes: 6 additions & 0 deletions src/include/migraphx/half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ struct common_type<migraphx::half, migraphx::half>
using type = migraphx::half;
};

// half combined with generic_float<7, 8> (the representation bf16.hpp aliases
// as bf16) resolves to float.
template <>
struct common_type<migraphx::half, migraphx::generic_float<7, 8>>
{
using type = float;
};

} // namespace std

#endif
3 changes: 3 additions & 0 deletions src/include/migraphx/quantization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ struct program;
MIGRAPHX_EXPORT void quantize_fp16(program& prog,
const std::vector<std::string>& ins_names = {"all"});

// Quantize `prog` to bf16. `ins_names` is the list of instruction names handed
// to the quantization pass and defaults to {"all"}.
MIGRAPHX_EXPORT void quantize_bf16(program& prog,
const std::vector<std::string>& ins_names = {"all"});

MIGRAPHX_EXPORT void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
Expand Down
4 changes: 3 additions & 1 deletion src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
Expand All @@ -52,6 +53,7 @@ struct MIGRAPHX_EXPORT shape
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(bool_type, bool) \
m(half_type, half) \
m(bf16_type, bf16) \
m(float_type, float) \
m(double_type, double) \
m(uint8_type, uint8_t) \
Expand All @@ -65,7 +67,7 @@ struct MIGRAPHX_EXPORT shape
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2)
// clang-format on
// clang-format on

#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TRUNCATE_FLOAT_HPP
#define MIGRAPHX_GUARD_RTGLIB_TRUNCATE_FLOAT_HPP

#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -35,12 +36,13 @@ struct program;
struct module;

/**
* quantize a program to fp16
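* truncate the floating-point instructions of a program to the lower-precision
* type selected by float_type (e.g. fp16 or bf16)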
*/
struct MIGRAPHX_EXPORT quantize_fp16_pass
struct MIGRAPHX_EXPORT truncate_float_pass
{
std::vector<std::string> ins_names = {"all"};
std::string name() const { return "quantize_fp16"; }
shape::type_t float_type;
std::string name() const { return "truncate_float"; }
void apply(module& m) const;
};

Expand Down
5 changes: 5 additions & 0 deletions src/include/migraphx/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>

Expand All @@ -53,6 +54,10 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)

// Extend the is_floating_point, is_signed and is_arithmetic traits to bf16.
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, bf16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, bf16)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, bf16)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)
Expand Down
4 changes: 4 additions & 0 deletions src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
&migraphx::quantize_fp16,
py::arg("prog"),
py::arg("ins_names") = std::vector<std::string>{"all"});
m.def("quantize_bf16",
&migraphx::quantize_bf16,
py::arg("prog"),
py::arg("ins_names") = std::vector<std::string>{"all"});
m.def("quantize_int8",
&migraphx::quantize_int8,
py::arg("prog"),
Expand Down
14 changes: 12 additions & 2 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/truncate_float.hpp>
#include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_int4.hpp>
#include <migraphx/simplify_reshapes.hpp>
Expand Down Expand Up @@ -69,7 +69,17 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
run_passes(prog,
{normalize_ops{},
optimize_module{{"quantizelinear", "dequantizelinear"}},
quantize_fp16_pass{ins_names},
truncate_float_pass{ins_names, shape::half_type},
optimize_module{{"quantizelinear", "dequantizelinear"}}},
quant_tracer());
}

// Quantize `prog` to bf16. Hands run_passes the pass list normalize_ops,
// optimize_module, truncate_float_pass (configured with `ins_names` and
// shape::bf16_type) and optimize_module again, together with quant_tracer().
void quantize_bf16(program& prog, const std::vector<std::string>& ins_names)
{
run_passes(prog,
{normalize_ops{},
optimize_module{{"quantizelinear", "dequantizelinear"}},
// Same pass type that quantize_fp16 uses; only the target float type differs.
truncate_float_pass{ins_names, shape::bf16_type},
optimize_module{{"quantizelinear", "dequantizelinear"}}},
quant_tracer());
}
Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::double_type: return rocblas_datatype_f64_r;
case shape::float_type: return rocblas_datatype_f32_r;
case shape::half_type: return rocblas_datatype_f16_r;
case shape::bf16_type: return rocblas_datatype_bf16_r;
case shape::int8_type: return rocblas_datatype_i8_r;
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
Expand Down
Loading
Loading