Skip to content

Commit 81281c0

Browse files
authored
perf(storage): remove protobuf's copy of data on unmarshalling (#9526)
1 parent a3bb7c0 commit 81281c0

4 files changed

Lines changed: 405 additions & 12 deletions

File tree

storage/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ require (
88
cloud.google.com/go v0.112.1
99
cloud.google.com/go/compute/metadata v0.2.3
1010
cloud.google.com/go/iam v1.1.6
11+
github.com/golang/protobuf v1.5.3
1112
github.com/google/go-cmp v0.6.0
1213
github.com/google/uuid v1.6.0
1314
github.com/googleapis/gax-go/v2 v2.12.2
@@ -26,7 +27,6 @@ require (
2627
github.com/go-logr/logr v1.4.1 // indirect
2728
github.com/go-logr/stdr v1.2.2 // indirect
2829
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
29-
github.com/golang/protobuf v1.5.3 // indirect
3030
github.com/google/martian/v3 v3.3.2 // indirect
3131
github.com/google/s2a-go v0.1.7 // indirect
3232
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect

storage/grpc_client.go

Lines changed: 252 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,18 @@ import (
2727
"cloud.google.com/go/internal/trace"
2828
gapic "cloud.google.com/go/storage/internal/apiv2"
2929
"cloud.google.com/go/storage/internal/apiv2/storagepb"
30+
"github.com/golang/protobuf/proto"
3031
"github.com/googleapis/gax-go/v2"
3132
"google.golang.org/api/googleapi"
3233
"google.golang.org/api/iterator"
3334
"google.golang.org/api/option"
3435
"google.golang.org/api/option/internaloption"
3536
"google.golang.org/grpc"
3637
"google.golang.org/grpc/codes"
38+
"google.golang.org/grpc/encoding"
3739
"google.golang.org/grpc/metadata"
3840
"google.golang.org/grpc/status"
41+
"google.golang.org/protobuf/encoding/protowire"
3942
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
4043
)
4144

@@ -902,12 +905,50 @@ func (c *grpcStorageClient) RewriteObject(ctx context.Context, req *rewriteObjec
902905
return r, nil
903906
}
904907

908+
// bytesCodec is a grpc codec which permits receiving messages as either
909+
// protobuf messages, or as raw []bytes.
910+
type bytesCodec struct {
911+
encoding.Codec
912+
}
913+
914+
func (bytesCodec) Marshal(v any) ([]byte, error) {
915+
vv, ok := v.(proto.Message)
916+
if !ok {
917+
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
918+
}
919+
return proto.Marshal(vv)
920+
}
921+
922+
func (bytesCodec) Unmarshal(data []byte, v any) error {
923+
switch v := v.(type) {
924+
case *[]byte:
925+
// If gRPC could recycle the data []byte after unmarshaling (through
926+
// buffer pools), we would need to make a copy here.
927+
*v = data
928+
return nil
929+
case proto.Message:
930+
return proto.Unmarshal(data, v)
931+
default:
932+
return fmt.Errorf("can not unmarshal type %T", v)
933+
}
934+
}
935+
936+
func (bytesCodec) Name() string {
937+
// If this isn't "", then gRPC sets the content-subtype of the call to this
938+
// value and we get errors.
939+
return ""
940+
}
941+
905942
func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRangeReaderParams, opts ...storageOption) (r *Reader, err error) {
906943
ctx = trace.StartSpan(ctx, "cloud.google.com/go/storage.grpcStorageClient.NewRangeReader")
907944
defer func() { trace.EndSpan(ctx, err) }()
908945

909946
s := callSettings(c.settings, opts...)
910947

948+
s.gax = append(s.gax, gax.WithGRPCOptions(
949+
grpc.ForceCodec(bytesCodec{}),
950+
))
951+
911952
if s.userProject != "" {
912953
ctx = setUserProjectMetadata(ctx, s.userProject)
913954
}
@@ -923,6 +964,8 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
923964
req.Generation = params.gen
924965
}
925966

967+
var databuf []byte
968+
926969
// Define a function that initiates a Read with offset and length, assuming
927970
// we have already read seen bytes.
928971
reopen := func(seen int64) (*readStreamResponse, context.CancelFunc, error) {
@@ -957,12 +1000,23 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
9571000
return err
9581001
}
9591002

960-
msg, err = stream.Recv()
1003+
// Receive the message into databuf as a wire-encoded message so we can
1004+
// use a custom decoder to avoid an extra copy at the protobuf layer.
1005+
err := stream.RecvMsg(&databuf)
9611006
// These types of errors show up on the Recv call, rather than the
9621007
// initialization of the stream via ReadObject above.
9631008
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
9641009
return ErrObjectNotExist
9651010
}
1011+
if err != nil {
1012+
return err
1013+
}
1014+
// Use a custom decoder that uses protobuf unmarshalling for all
1015+
// fields except the checksummed data.
1016+
// Subsequent receives in Read calls will skip all protobuf
1017+
// unmarshalling and directly read the content from the gRPC []byte
1018+
// response, since only the first call will contain other fields.
1019+
msg, err = readFullObjectResponse(databuf)
9661020

9671021
return err
9681022
}, s.retry, s.idempotent)
@@ -1008,6 +1062,7 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
10081062
leftovers: msg.GetChecksummedData().GetContent(),
10091063
settings: s,
10101064
zeroRange: params.length == 0,
1065+
databuf: databuf,
10111066
},
10121067
}
10131068

@@ -1406,6 +1461,7 @@ type gRPCReader struct {
14061461
stream storagepb.Storage_ReadObjectClient
14071462
reopen func(seen int64) (*readStreamResponse, context.CancelFunc, error)
14081463
leftovers []byte
1464+
databuf []byte
14091465
cancel context.CancelFunc
14101466
settings *settings
14111467
}
@@ -1436,7 +1492,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
14361492
}
14371493

14381494
// Attempt to Recv the next message on the stream.
1439-
msg, err := r.recv()
1495+
content, err := r.recv()
14401496
if err != nil {
14411497
return 0, err
14421498
}
@@ -1448,7 +1504,6 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
14481504
// present in the response here.
14491505
// TODO: Figure out if we need to support decompressive transcoding
14501506
// https://cloud.google.com/storage/docs/transcoding.
1451-
content := msg.GetChecksummedData().GetContent()
14521507
n = copy(p[n:], content)
14531508
leftover := len(content) - n
14541509
if leftover > 0 {
@@ -1471,18 +1526,20 @@ func (r *gRPCReader) Close() error {
14711526
return nil
14721527
}
14731528

1474-
// recv attempts to Recv the next message on the stream. In the event
1475-
// that a retryable error is encountered, the stream will be closed, reopened,
1476-
// and Recv again. This will attempt to Recv until one of the following is true:
1529+
// recv attempts to Recv the next message on the stream and extract the object
1530+
// data that it contains. In the event that a retryable error is encountered,
1531+
// the stream will be closed, reopened, and RecvMsg again.
1532+
// This will attempt to Recv until one of the following is true:
14771533
//
14781534
// * Recv is successful
14791535
// * A non-retryable error is encountered
14801536
// * The Reader's context is canceled
14811537
//
14821538
// The last error received is the one that is returned, which could be from
14831539
// an attempt to reopen the stream.
1484-
func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
1485-
msg, err := r.stream.Recv()
1540+
func (r *gRPCReader) recv() ([]byte, error) {
1541+
err := r.stream.RecvMsg(&r.databuf)
1542+
14861543
var shouldRetry = ShouldRetry
14871544
if r.settings.retry != nil && r.settings.retry.shouldRetry != nil {
14881545
shouldRetry = r.settings.retry.shouldRetry
@@ -1492,10 +1549,195 @@ func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
14921549
// reopen the stream, but will backoff if further attempts are necessary.
14931550
// Reopening the stream Recvs the first message, so if retrying is
14941551
// successful, the next logical chunk will be returned.
1495-
msg, err = r.reopenStream()
1552+
msg, err := r.reopenStream()
1553+
return msg.GetChecksummedData().GetContent(), err
1554+
}
1555+
1556+
if err != nil {
1557+
return nil, err
1558+
}
1559+
1560+
return readObjectResponseContent(r.databuf)
1561+
}
1562+
1563+
// ReadObjectResponse field and subfield numbers.
1564+
const (
1565+
checksummedDataField = protowire.Number(1)
1566+
checksummedDataContentField = protowire.Number(1)
1567+
checksummedDataCRC32CField = protowire.Number(2)
1568+
objectChecksumsField = protowire.Number(2)
1569+
contentRangeField = protowire.Number(3)
1570+
metadataField = protowire.Number(4)
1571+
)
1572+
1573+
// readObjectResponseContent returns the checksummed_data.content field of a
1574+
// ReadObjectResponse message, or an error if the message is invalid.
1575+
// This can be used on recvs of objects after the first recv, since only the
1576+
// first message will contain non-data fields.
1577+
func readObjectResponseContent(b []byte) ([]byte, error) {
1578+
checksummedData, err := readProtoBytes(b, checksummedDataField)
1579+
if err != nil {
1580+
return b, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err)
1581+
}
1582+
content, err := readProtoBytes(checksummedData, checksummedDataContentField)
1583+
if err != nil {
1584+
return content, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err)
14961585
}
14971586

1498-
return msg, err
1587+
return content, nil
1588+
}
1589+
1590+
// readFullObjectResponse returns the ReadObjectResponse that is encoded in the
1591+
// wire-encoded message buffer b, or an error if the message is invalid.
1592+
// This must be used on the first recv of an object as it may contain all fields
1593+
// of ReadObjectResponse, and we use or pass on those fields to the user.
1594+
// This function is essentially identical to proto.Unmarshal, except it aliases
1595+
// the data in the input []byte. If the proto library adds a feature to
1596+
// Unmarshal that does that, this function can be dropped.
1597+
func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) {
1598+
msg := &storagepb.ReadObjectResponse{}
1599+
1600+
// Loop over the entire message, extracting fields as we go. This does not
1601+
// handle field concatenation, in which the contents of a single field
1602+
// are split across multiple protobuf tags.
1603+
off := 0
1604+
for off < len(b) {
1605+
// Consume the next tag. This will tell us which field is next in the
1606+
// buffer, its type, and how much space it takes up.
1607+
fieldNum, fieldType, fieldLength := protowire.ConsumeTag(b[off:])
1608+
if fieldLength < 0 {
1609+
return nil, protowire.ParseError(fieldLength)
1610+
}
1611+
off += fieldLength
1612+
1613+
// Unmarshal the field according to its type. Only fields that are not
1614+
// nil will be present.
1615+
switch {
1616+
case fieldNum == checksummedDataField && fieldType == protowire.BytesType:
1617+
// The ChecksummedData field was found. Initialize the struct.
1618+
msg.ChecksummedData = &storagepb.ChecksummedData{}
1619+
1620+
// Get the bytes corresponding to the checksummed data.
1621+
fieldContent, n := protowire.ConsumeBytes(b[off:])
1622+
if n < 0 {
1623+
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", protowire.ParseError(n))
1624+
}
1625+
off += n
1626+
1627+
// Get the nested fields. We need to do this manually as it contains
1628+
// the object content bytes.
1629+
contentOff := 0
1630+
for contentOff < len(fieldContent) {
1631+
gotNum, gotTyp, n := protowire.ConsumeTag(fieldContent[contentOff:])
1632+
if n < 0 {
1633+
return nil, protowire.ParseError(n)
1634+
}
1635+
contentOff += n
1636+
1637+
switch {
1638+
case gotNum == checksummedDataContentField && gotTyp == protowire.BytesType:
1639+
// Get the content bytes.
1640+
bytes, n := protowire.ConsumeBytes(fieldContent[contentOff:])
1641+
if n < 0 {
1642+
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", protowire.ParseError(n))
1643+
}
1644+
msg.ChecksummedData.Content = bytes
1645+
contentOff += n
1646+
case gotNum == checksummedDataCRC32CField && gotTyp == protowire.Fixed32Type:
1647+
v, n := protowire.ConsumeFixed32(fieldContent[contentOff:])
1648+
if n < 0 {
1649+
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %v", protowire.ParseError(n))
1650+
}
1651+
msg.ChecksummedData.Crc32C = &v
1652+
contentOff += n
1653+
default:
1654+
n = protowire.ConsumeFieldValue(gotNum, gotTyp, fieldContent[contentOff:])
1655+
if n < 0 {
1656+
return nil, protowire.ParseError(n)
1657+
}
1658+
contentOff += n
1659+
}
1660+
}
1661+
case fieldNum == objectChecksumsField && fieldType == protowire.BytesType:
1662+
// The field was found. Initialize the struct.
1663+
msg.ObjectChecksums = &storagepb.ObjectChecksums{}
1664+
1665+
// Get the bytes corresponding to the checksums.
1666+
bytes, n := protowire.ConsumeBytes(b[off:])
1667+
if n < 0 {
1668+
return nil, fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", protowire.ParseError(n))
1669+
}
1670+
off += n
1671+
1672+
// Unmarshal.
1673+
if err := proto.Unmarshal(bytes, msg.ObjectChecksums); err != nil {
1674+
return nil, err
1675+
}
1676+
case fieldNum == contentRangeField && fieldType == protowire.BytesType:
1677+
msg.ContentRange = &storagepb.ContentRange{}
1678+
1679+
bytes, n := protowire.ConsumeBytes(b[off:])
1680+
if n < 0 {
1681+
return nil, fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", protowire.ParseError(n))
1682+
}
1683+
off += n
1684+
1685+
if err := proto.Unmarshal(bytes, msg.ContentRange); err != nil {
1686+
return nil, err
1687+
}
1688+
case fieldNum == metadataField && fieldType == protowire.BytesType:
1689+
msg.Metadata = &storagepb.Object{}
1690+
1691+
bytes, n := protowire.ConsumeBytes(b[off:])
1692+
if n < 0 {
1693+
return nil, fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", protowire.ParseError(n))
1694+
}
1695+
off += n
1696+
1697+
if err := proto.Unmarshal(bytes, msg.Metadata); err != nil {
1698+
return nil, err
1699+
}
1700+
default:
1701+
fieldLength = protowire.ConsumeFieldValue(fieldNum, fieldType, b[off:])
1702+
if fieldLength < 0 {
1703+
return nil, fmt.Errorf("default: %v", protowire.ParseError(fieldLength))
1704+
}
1705+
off += fieldLength
1706+
}
1707+
}
1708+
1709+
return msg, nil
1710+
}
1711+
1712+
// readProtoBytes returns the contents of the protobuf field with number num
1713+
// and type bytes from a wire-encoded message. If the field cannot be found,
1714+
// the returned slice will be nil and no error will be returned.
1715+
//
1716+
// It does not handle field concatenation, in which the contents of a single field
1717+
// are split across multiple protobuf tags. Encoded data containing split fields
1718+
// of this form is technically permissable, but uncommon.
1719+
func readProtoBytes(b []byte, num protowire.Number) ([]byte, error) {
1720+
off := 0
1721+
for off < len(b) {
1722+
gotNum, gotTyp, n := protowire.ConsumeTag(b[off:])
1723+
if n < 0 {
1724+
return nil, protowire.ParseError(n)
1725+
}
1726+
off += n
1727+
if gotNum == num && gotTyp == protowire.BytesType {
1728+
b, n := protowire.ConsumeBytes(b[off:])
1729+
if n < 0 {
1730+
return nil, protowire.ParseError(n)
1731+
}
1732+
return b, nil
1733+
}
1734+
n = protowire.ConsumeFieldValue(gotNum, gotTyp, b[off:])
1735+
if n < 0 {
1736+
return nil, protowire.ParseError(n)
1737+
}
1738+
off += n
1739+
}
1740+
return nil, nil
14991741
}
15001742

15011743
// reopenStream "closes" the existing stream and attempts to reopen a stream and

0 commit comments

Comments
 (0)