Skip to content

Commit

Permalink
refactoring as_ivy_dev and as_native_dev in all backends
Browse files Browse the repository at this point in the history
  • Loading branch information
abdulasiraj committed Aug 16, 2023
1 parent 728c907 commit e9b54ad
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 65 deletions.
66 changes: 25 additions & 41 deletions ivy/functional/backends/jax/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ def _to_device(x, device=None):
return x


def as_ivy_dev(device: Union[jaxlib.xla_extension.Device, ivy.Device, str], /):
def dev_helper(func_name, device):
device_object = device
if isinstance(device, str):
device_kind, device_id = ivy.parse_device_str(device)
elif isinstance(device, ivy.Device):
elif (isinstance(device, ivy.Device) and func_name == "as_ivy_dev") or (isinstance(device, jaxlib.xla_extension.Device) and func_name == "as_native_dev"):
return device

elif isinstance(device, jaxlib.xla_extension.Device):
device_kind, device_id = (device.platform, device.id)

Expand All @@ -96,60 +97,43 @@ def as_ivy_dev(device: Union[jaxlib.xla_extension.Device, ivy.Device, str], /):
ivy.logging.warning(
f"The device '{device_object}' does not exist on this host. Falling back to '{device_kind}'"
)
device_object = ivy.Device(device_kind)
if func_name == "as_ivy_dev":
device_object = ivy.Device(device_kind)
else:
device_object = jax.devices(device_kind)[0]
else:
ivy.logging.warning(
f"The device '{device_object}' does not exist on this host. Falling back to '{device_kind}:{devices[device_kind][0]}'"
)
device_object = ivy.Device(
device_kind + ":" + str(devices[device_kind][0])
)
if func_name == "as_ivy_dev":
device_object = ivy.Device(
device_kind + ":" + str(devices[device_kind][0])
)
else:
device_object = jax.devices(device_kind)[devices[device_kind][0]]
else:
if device_kind == "cpu":
device_object = ivy.Device(device_kind)
if func_name == "as_ivy_dev":
device_object = ivy.Device(device_kind)
else:
device_object = jax.devices(device_kind)[0]
else:
device_object = ivy.Device(device_kind + ":" + str(device_id))
if func_name == "as_ivy_dev":
device_object = ivy.Device(device_kind + ":" + str(device_id))
else:
device_object = jax.devices(device_kind)[device_id]
return device_object
raise ivy.utils.exceptions.IvyValueError(
f"The device '{device}' does not exist on this host."
)


def as_native_dev(device: Union[jaxlib.xla_extension.Device, ivy.Device, str], /):
device_object = device
if isinstance(device, str):
device_kind, device_id = ivy.parse_device_str(device)
elif isinstance(device, jaxlib.xla_extension.Device):
return device
else:
device_kind, device_id = (ivy.get_device_kind(), ivy.get_device_id())

# Verify if the device exists on the host
available_devices = jax.local_devices()
devices = {}
for i in available_devices:
if i.device_kind not in devices:
devices[i.device_kind] = []
devices[i.device_kind].append(i.id)
def as_ivy_dev(device: Union[jaxlib.xla_extension.Device, ivy.Device, str], /):
return dev_helper("as_ivy_dev", device)

if device_kind in devices:
if device_id not in devices[device_kind]:
if device_kind == "cpu":
ivy.logging.warning(
f"The device '{device_object}' does not exist on this host. Falling back to '{device_kind}'"
)
device_object = jax.devices(device_kind)[0]
else:
ivy.logging.warning(
f"The device '{device_object}' does not exist on this host. Falling back to '{device_kind}:{devices[device_kind][0]}'"
)
device_object = jax.devices(device_kind)[devices[device_kind][0]]
else:
device_object = jax.devices(device_kind)[device_id]
return device_object
raise ivy.utils.exceptions.IvyValueError(
f"The device '{device}' does not exist on this host."
)
def as_native_dev(device: Union[jaxlib.xla_extension.Device, ivy.Device, str], /):
return dev_helper("as_native_dev", device)


def clear_cached_mem_on_dev(device: str, /):
Expand Down
47 changes: 23 additions & 24 deletions ivy/functional/backends/torch/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,48 +45,47 @@ def to_device(
return ret


def as_ivy_dev(device: Union[torch.device, ivy.Device, str], /):
def dev_helper(func_name: str = None, device):
if isinstance(device, str):
dev_type, dev_idx = ivy.parse_device_str(device)
elif isinstance(device, ivy.Device):
elif isinstance(device, ivy.Device) and func_name == "as_ivy_dev":
return device
elif isinstance(device, torch.device) and func_name == "as_native_dev":
return device
else:
dev_type, dev_idx = (device.type, device.index)
if dev_type == "cpu":
return ivy.Device(dev_type)

num_gpus = torch.cuda.device_count()
if dev_type == "cpu":
if func_name == "as_ivy_dev":
return ivy.Device(dev_type)
else:
return torch.device("cpu")
if dev_idx < num_gpus:
return ivy.Device(dev_type + ":" + dev_idx)
if func_name == "as_ivy_dev":
return ivy.Device(dev_type + ":" + dev_idx)
else:
return torch.device(dev_type.replace("gpu", "cuda") + ":" + dev_idx)
else:
ivy.logging.warning(
f"The device '{dev_type}:{dev_idx}' does not exist on this host. Falling back to '{dev_type}:{0}'"
)
return ivy.Device(dev_type.replace("cuda", "gpu") + (":0"))
if func_name == "as_ivy_dev":
return ivy.Device(dev_type.replace("cuda", "gpu") + (":0"))
else:
return torch.device(dev_type.replace("gpu", "cuda") + ":0")


def as_ivy_dev(device: Union[torch.device, ivy.Device, str], /):
return dev_helper("as_ivy_dev", device)


def as_native_dev(
device: Union[torch.device, ivy.Device, str],
/,
) -> Optional[torch.device]:
if isinstance(device, str):
dev_type, dev_idx = ivy.parse_device_str(device)
elif isinstance(device, torch.device):
return device
else:
dev_type, dev_idx = (device.type, device.index)

if dev_type == "cpu":
return torch.device("cpu")

num_gpus = torch.cuda.device_count()
if dev_idx < num_gpus:
return torch.device(dev_type.replace("gpu", "cuda") + ":" + dev_idx)
else:
ivy.logging.warning(
f"The device '{dev_type}:{dev_idx}' does not exist on this host. Falling back to '{dev_type}:{0}'"
)
return torch.device(dev_type.replace("gpu", "cuda") + ":0")
return dev_helper("as_native_dev", device)



def clear_cached_mem_on_dev(device: Union[ivy.Device, torch.device], /) -> None:
Expand Down

0 comments on commit e9b54ad

Please sign in to comment.