Skip to content

Commit

Permalink
fix bug under sparse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vlan79 authored Sep 25, 2023
1 parent a46885c commit 6400436
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mindspore_rec/ops/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def construct(self, indices):
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
out = self.reshape(weight_flatten, shape)
else:
out = self.embedding_table.get(indices)
shape = self.shape(indices) + (self.embedding_size,)
indices_flatten = self.reshape_first(indices, (-1,))
out = self.map_tensor_get(self.embedding_table, indices_flatten)
out = self.reshape(out, shape)

if self.max_norm is not None:
axis = _make_axis_range(F.rank(indices), F.rank(out))
Expand Down

0 comments on commit 6400436

Please sign in to comment.