diff --git a/pkg/lvm/nodeserver.go b/pkg/lvm/nodeserver.go index 469a09db..53082c44 100644 --- a/pkg/lvm/nodeserver.go +++ b/pkg/lvm/nodeserver.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "os/exec" + "strconv" "strings" "context" @@ -30,6 +31,7 @@ import ( "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/api/resource" "k8s.io/klog/v2" ) @@ -108,13 +110,9 @@ func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis // if ephemeral is specified, create volume here if ephemeralVolume { - val := req.GetVolumeContext()["size"] - if val == "" { - return nil, status.Error(codes.InvalidArgument, "ephemeral inline volume is missing size parameter") - } - size, err := units.RAMInBytes(val) + size, err := parseSize(req.GetVolumeContext()["size"]) if err != nil { - return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("failed to parse size(%s) of ephemeral inline volume: %s", val, err.Error())) + return nil, status.Error(codes.InvalidArgument, err.Error()) } volID := req.GetVolumeId() @@ -124,7 +122,7 @@ func (ns *nodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublis return nil, fmt.Errorf("unable to create vg: %w output:%s", err, output) } - output, err = CreateLVS(ns.vgName, volID, uint64(size), req.GetVolumeContext()["type"]) + output, err = CreateLVS(ns.vgName, volID, size, req.GetVolumeContext()["type"]) if err != nil { return nil, fmt.Errorf("unable to create lv: %w output:%s", err, output) } @@ -331,3 +329,39 @@ func (ns *nodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandV }, nil } + +func parseSize(val string) (uint64, error) { + if val == "" { + return 0, fmt.Errorf("ephemeral inline volume is missing size parameter") + } + + parseWithKubernetes := func(raw string) (uint64, error) { + sizeQuantity, err := resource.ParseQuantity(raw) + if err != nil { + return 0, fmt.Errorf("failed to parse size (%s) of ephemeral inline volume: %w", raw, err) + } + + size, err := strconv.ParseUint(sizeQuantity.AsDec().String(), 10, 64) + if err != nil { + return 0, fmt.Errorf("parsed volume size (%s) of ephemeral inline volume does not fit into an uint64: %w", raw, err) + } + + return size, nil + } + + // this was the initial method to parse the size and has to be kept for compatibility reasons + parseWithGoUnits := func(raw string) (uint64, error) { + size, err := units.RAMInBytes(raw) + if err != nil { + return 0, fmt.Errorf("failed to parse size (%s) of ephemeral inline volume: %w", raw, err) + } + + return uint64(size), nil + } + + if size, err := parseWithKubernetes(val); err == nil { + return size, nil + } + + return parseWithGoUnits(val) +} diff --git a/pkg/lvm/nodeserver_test.go b/pkg/lvm/nodeserver_test.go new file mode 100644 index 00000000..7e8e3314 --- /dev/null +++ b/pkg/lvm/nodeserver_test.go @@ -0,0 +1,53 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package lvm + +import "testing" + +func Test_parseSize(t *testing.T) { + tests := []struct { + name string + val string + want uint64 + wantErr bool + }{ + { + name: "parse size compatible only with k8s", + val: "1Gi", + want: 1 << 30, + wantErr: false, + }, + { + name: "parse size compatible only with go-units", + val: "1GB", + want: 1 << 30, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseSize(tt.val) + if (err != nil) != tt.wantErr { + t.Errorf("parseSize() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("parseSize() = %v, want %v", got, tt.want) + } + }) + } +}