grpc-gatewayでProtocol Buffers over HTTP

grpc-gatewayを使うとgRPCサーバをRESTfulなインターフェースで叩けるようになります。
APIクライアントはswaggerから生成しても良いのですが、Goだとprotocで生成したstructをPOSTする方が依存も少なく、楽なこともあるでしょう。
ということで軽く試してみました。

Protocol Buffersのシリアライズ/デシリアライズ

まず、サーバがProtocol Buffersをやりとりできるようにしてあげないといけません。
そのための機能は全てgrpc-gatewayruntimeパッケージにあるので、以下のようにしてServeMuxを初期化すればOKです。

mux := runtime.NewServeMux(
    runtime.WithMarshalerOption("application/octet-stream", new(runtime.ProtoMarshaller)),
)

APIクライアント

これでクライアントからapplication/octet-streamでProtocol Buffersを投げることができるようになりました。

func do(url string, in *pb.ExampleRequest) (*pb.ExampleResponse, error) {
    body := new(bytes.Buffer)
    if err := new(runtime.ProtoMarshaller).NewEncoder(body).Encode(in); err != nil {
        return nil, err
    }

    res, err := http.Post(url, "application/octet-stream", body)
    if err != nil {
        return nil, err
    }
    defer res.Body.Close()

    if res.StatusCode >= 400 {
        b, err := ioutil.ReadAll(res.Body)
        if err != nil {
            return nil, errors.New(res.Status)
        }
        return nil, fmt.Errorf("%s: %s", res.Status, b)
    }

    out := new(pb.ExampleResponse)
    if err := new(runtime.ProtoMarshaller).NewDecoder(res.Body).Decode(out); err != nil {
        return nil, err
    }
    return out, err
}

エラーの型を共通化する

このままでも普通に使う分には問題ありませんが、サーバがエラーレスポンスを返した際もBodyをstructにデコードできるとより嬉しいです。
現状だと、サーバ側のエラーレスポンスはデフォルトで以下のunexportedな型になっています。

type errorBody struct {
    Error   string     `protobuf:"bytes,1,name=error" json:"error"`
    // This is to make the error more compatible with users that expect errors to be Status objects:
    // https://github.com/grpc/grpc/blob/master/src/proto/grpc/status/status.proto
    // It should be the exact same message as the Error field.
    Message string     `protobuf:"bytes,1,name=message" json:"message"`
    Code    int32      `protobuf:"varint,2,name=code" json:"code"`
    Details []*any.Any `protobuf:"bytes,3,rep,name=details" json:"details,omitempty"`
}

// Make this also conform to proto.Message for builtin JSONPb Marshaler
func (e *errorBody) Reset()         { *e = errorBody{} }
func (e *errorBody) String() string { return proto.CompactTextString(e) }
func (*errorBody) ProtoMessage()    {}

そのため、クライアント側でデコードするためにはこちらをコピーして使うか、サーバ側で型を変更する必要があります。

エラーの型を変える場合はWithProtoErrorHandlerを使います。

mux := runtime.NewServeMux(
    runtime.WithMarshalerOption("application/octet-stream", new(runtime.ProtoMarshaller)),
    runtime.WithProtoErrorHandler(func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) {
        w.Header().Set("Content-Type", marshaler.ContentType())

        s, ok := status.FromError(err)
        if !ok {
            s = status.New(codes.Unknown, err.Error())
        }

        buf, merr := marshaler.Marshal(s.Proto())
        if merr != nil {
            w.WriteHeader(http.StatusInternalServerError)
            io.WriteString(w, `{"error": "failed to marshal error message"}`)
            return
        }

        w.WriteHeader(runtime.HTTPStatusFromCode(s.Code()))
        w.Write(buf)
    }),
)

内容はDefaultHTTPErrorとほとんど同じですが、エラーの型をgoogle.golang.org/genproto/googleapis/rpc/status(*spb.Status)にしています。

こうすることで、クライアント側はgoogle.golang.org/genproto/googleapis/rpc/statusspb*1としてimportし、proto.Unmarshal*spb.Statusに戻せるようになります。

if res.StatusCode >= 400 {
    b, err := ioutil.ReadAll(res.Body)
    if err != nil {
        return nil, errors.New(res.Status)
    }

    v := new(spb.Status)
    if err := proto.Unmarshal(b, v); err != nil {
        return nil, fmt.Errorf("%s: %s", res.Status, b)
    }
    return nil, status.ErrorProto(v)
}

返ってきたエラーはstatus.FromError*status.Statusにすると、CodeMessageDetailsが取り出せるので後続の処理で自由に使えます。

if s, ok := status.FromError(err); ok {
    log.Printf("code: %d, message: %s, details: %v", s.Code(), s.Message(), s.Proto().Details)
}

なお、Detailsは型情報が落ちてしまっているため、内容を正しく出力するにはptypes.UnmarshalAnyで元に戻してあげる必要があります。

例えばDetailserrdetails.BadRequestの場合だと以下のようなコードです。

if s, ok := status.FromError(err); ok {
    var details []string
    for _, d := range s.Proto().Details {
        var m errdetails.BadRequest
        if err := ptypes.UnmarshalAny(d, &m); err == nil {
            details = append(details, m.String())
        }
    }
    log.Printf("code: %d, message: %s, details: %v", s.Code(), s.Message(), details)
}

s.Details()だと[field_violations:<field:"data" description:"invalud" > ]
s.Proto().Detailsだと[type_url:"type.googleapis.com/google.rpc.BadRequest" value:"\n\017\n\004data\022\007invalud" ]になってしまうのが、
ちゃんと[field_violations:<field:"data" description:"invalud" > ]になります。

*1:statusだと"google.golang.org/grpc/status"とかぶるため