Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Take into account frontend supported dtypes in frontend tests #23430

Merged
merged 11 commits into from
Sep 19, 2023
Merged
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions ivy_tests/test_ivy/helpers/hypothesis_helpers/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# local
import ivy
from ..pipeline_helper import BackendHandler, get_frontend_config
from ..pipeline_helper import BackendHandler, get_frontend_config, WithBackendContext
from . import number_helpers as nh
from . import array_helpers as ah
from .. import globals as test_globals
Expand Down Expand Up @@ -50,9 +50,13 @@ def _get_type_dict(framework: str, kind: str, is_frontend_test=False):
def _get_type_dict_helper(framework, kind, is_frontend_test):
if is_frontend_test:
framework_module = get_frontend_config(framework).supported_dtypes
return _retrieve_requested_dtypes(framework_module, kind)
else:
framework_module = ivy
with WithBackendContext(framework) as ivy_backend:
return _retrieve_requested_dtypes(ivy_backend, kind)


def _retrieve_requested_dtypes(framework_module, kind):
ReneFabricius marked this conversation as resolved.
Show resolved Hide resolved
if kind == "valid":
return framework_module.valid_dtypes
if kind == "numeric":
Expand Down Expand Up @@ -138,7 +142,7 @@ def get_dtypes(
function as the keyword argument with the given name.
prune_function
if True, the function will prune the data types to only include the ones that
are supported by the current backend. If False, the function will return all
are supported by the current function. If False, the function will return all
the data types supported by the current backend.

Returns
Expand Down Expand Up @@ -225,6 +229,11 @@ def get_dtypes(
# FN_DTYPES & BACKEND_DTYPES & FRONTEND_DTYPES & GROUND_TRUTH_DTYPES

# If being called from a frontend test
if test_globals.CURRENT_FRONTEND is not test_globals._Notsetval:
frontend_dtypes = _get_type_dict_helper(
test_globals.CURRENT_FRONTEND, kind, True
)
valid_dtypes = valid_dtypes.intersection(frontend_dtypes)

# Make sure we return dtypes that are compatible with ground truth backend
ground_truth_is_set = (
Expand Down
Loading