Skip to content

Commit

Permalink
Reverts ba34084
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627192322
  • Loading branch information
yashk2810 authored and copybara-github committed Apr 22, 2024
1 parent 6066078 commit 4cf2561
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 40 deletions.
1 change: 0 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,6 @@ cc_library(
":dot_as_convolution_util",
":hlo_graph_dumper",
":hlo_pass",
":host_memory_offload_annotations_hdr",
"//xla:array",
"//xla:protobuf_util",
"//xla:shape_tree",
Expand Down
10 changes: 3 additions & 7 deletions xla/service/sharding_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ limitations under the License.
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/protobuf_util.h"
#include "xla/service/dot_as_convolution_util.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -278,12 +277,9 @@ bool IsPassthroughCustomOps(const HloInstruction* hlo) {
hlo->operand(0)->shape().rank() != hlo->shape().rank()) {
return false;
}

return hlo->IsCustomCall(
{"ResizeNearest", "ResizeBilinear", "ResizeNearestGrad",
"ResizeBilinearGrad", "Cholesky",
host_memory_offload_annotations::kMoveToDeviceCustomCallTarget,
host_memory_offload_annotations::kMoveToHostCustomCallTarget});
return hlo->IsCustomCall({"ResizeNearest", "ResizeBilinear",
"ResizeNearestGrad", "ResizeBilinearGrad",
"Cholesky"});
}

// Return the operand which is the most suitable for determining the sharding
Expand Down
32 changes: 0 additions & 32 deletions xla/service/sharding_propagation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9374,38 +9374,6 @@ ENTRY %reshape {
EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}"));
}

TEST_F(ShardingPropagationTest, OffloadingPropagation) {
const char* const hlo_string = R"(
HloModule module
ENTRY %offloading {
%param0 = f32[1,256,128] parameter(0), sharding={devices=[1,1,4]0,1,2,3}
%zero = f32[] constant(0.0)
%broadcast = f32[256,256,128] broadcast(%zero), dimensions={}
%izero = s32[] constant(0)
%custom-call.0 = f32[1,256,128] custom-call(f32[1,256,128] %param0), custom_call_target="MoveToHost"
%dynamic-update-slice = f32[256,256,128] dynamic-update-slice(%broadcast, %custom-call.0, %izero, %izero, %izero)
%dynamic-slice = f32[1,256,128] dynamic-slice(%dynamic-update-slice, %izero, %izero, %izero), dynamic_slice_sizes={1,256,128}
%custom-call.1 = f32[1,256,128] custom-call(f32[1,256,128] %dynamic-slice), custom_call_target="MoveToDevice"
ROOT %copy = f32[1,256,128] copy(%custom-call.1), sharding={devices=[1,4,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true)
.Run(module.get()));

XLA_VLOG_LINES(1, module->ToString());
EXPECT_TRUE(changed);

auto* to_host = FindInstruction(module.get(), "custom-call.0");
EXPECT_THAT(to_host, op::Sharding("{devices=[1,1,4]0,1,2,3}"));

auto* from_host_input =
FindInstruction(module.get(), "custom-call.1")->operand(0);
EXPECT_THAT(from_host_input, op::Sharding("{devices=[1,1,4]0,1,2,3}"));
}

TEST_P(ParameterizedMetadataTest, PropagateThroughSingleUsers) {
const char* const hlo_string = R"(
HloModule module
Expand Down

0 comments on commit 4cf2561

Please sign in to comment.